From 286d5e7859d4020d41ee2a4ade4c4f7749d69c79 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 09:58:30 +0200 Subject: [PATCH 01/30] ENH: refacto MapieRegressor and ConformityScore --- mapie/conformity_scores/conformity_scores.py | 210 +++++++- .../residual_conformity_scores.py | 39 +- mapie/regression/estimator.py | 504 ++++++++++++++++++ mapie/regression/regression.py | 382 +++---------- 4 files changed, 791 insertions(+), 344 deletions(-) create mode 100644 mapie/regression/estimator.py diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index c5f68974..87976c76 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -1,8 +1,11 @@ from abc import ABCMeta, abstractmethod import numpy as np +from typing import Tuple +from mapie._compatibility import np_nanquantile from mapie._typing import ArrayLike, NDArray +from sklearn.base import RegressorMixin class ConformityScore(metaclass=ABCMeta): @@ -31,7 +34,9 @@ class ConformityScore(metaclass=ABCMeta): - ``get_signed_conformity_scores`` The following equality must be verified: ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred) + X, + y_pred, + self.get_conformity_scores(X, y, y_pred) ) == y`` It should be specified if ``consistency_check==True``. @@ -51,6 +56,7 @@ def __init__( @abstractmethod def get_signed_conformity_scores( self, + X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, ) -> NDArray: @@ -63,47 +69,57 @@ def get_signed_conformity_scores( Parameters ---------- - y: NDArray - Observed values. + X: ArrayLike + Observed feature values. - y_pred: NDArray - Predicted values. + y: ArrayLike + Observed target values. + + y_pred: ArrayLike + Predicted target values. Returns ------- NDArray - Unsigned conformity scores. + Signed conformity scores. """ @abstractmethod def get_estimation_distribution( self, + X: ArrayLike, y_pred: ArrayLike, - conformity_scores: ArrayLike, + values: ArrayLike ) -> NDArray: """ - Placeholder for ``get_estimation_distribution``. + Placeholder for ``get_signed_conformity_scores``. Subclasses should implement this method! Compute samples of the estimation distribution from the predicted - values and the conformity scores. + targets and ``values``that can be either the conformity scores or + the conformity scores aggregated with the predictions. Parameters ---------- - y_pred: NDArray - Predicted values. + X: ArrayLike + Observed feature values. - conformity_scores: NDArray - Conformity scores. + y_pred: ArrayLike + Predicted values, it can be any type of predictions + (multi, low, up, ...). + + values: ArrayLike + Either the conformity scores or the conformity scores aggregated + with the predictions according to the subclass formula. Returns ------- - NDArray - Observed values. + ArrayLike """ def check_consistency( self, + X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, conformity_scores: ArrayLike, @@ -114,16 +130,24 @@ def check_consistency( The following equality should be verified: ``self.get_estimation_distribution( - y_pred, self.get_conformity_scores(y, y_pred) + X, + y_pred, + self.get_conformity_scores(X, y, y_pred) ) == y`` Parameters ---------- - y: NDArray - Observed values. + X: ArrayLike + Observed feature values. - y_pred: NDArray - Predicted values. + y: ArrayLike + Observed target values. + + y_pred: ArrayLike + Predicted target values. + + conformity_scores: ArrayLike + Conformity scores. Raises ------ @@ -131,7 +155,7 @@ def check_consistency( If the two methods are not consistent. """ score_distribution = self.get_estimation_distribution( - y_pred, conformity_scores + X, y_pred, conformity_scores ) abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) max_conf_score = np.max(abs_conformity_scores) @@ -141,8 +165,8 @@ def check_consistency( "get_estimation_distribution of the ConformityScore class " "are not consistent. " "The following equation must be verified: " - "self.get_estimation_distribution(y_pred, " - "self.get_conformity_scores(y, y_pred)) == y. " # noqa: E501 + "self.get_estimation_distribution(X, y_pred, " + "self.get_conformity_scores(X, y, y_pred)) == y" # noqa: E501 f"The maximum conformity score is {max_conf_score}." "The eps attribute may need to be increased if you are " "sure that the two methods are consistent." @@ -150,6 +174,7 @@ def check_consistency( def get_conformity_scores( self, + X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, ) -> NDArray: @@ -158,20 +183,151 @@ def get_conformity_scores( Parameters ---------- + X: NDArray + Observed feature values. + y: NDArray - Observed values. + Observed target values. y_pred: NDArray - Predicted values. + Predicted target values. Returns ------- NDArray Conformity scores. """ - conformity_scores = self.get_signed_conformity_scores(y, y_pred) + conformity_scores = self.get_signed_conformity_scores(X, y, y_pred) if self.consistency_check: - self.check_consistency(y, y_pred, conformity_scores) + self.check_consistency(X, y, y_pred, conformity_scores) if self.sym: conformity_scores = np.abs(conformity_scores) return conformity_scores + + @staticmethod + def _get_quantile( + values: NDArray, + alpha_np: NDArray, + axis: int, + method: str + ) -> NDArray: + """ + Compute the alpha quantile of the conformity scores considering + the symmetrical property if so. + + Parameters + ---------- + values: NDArray + Values from which the quantile is computed, it can be the + conformity scores or the conformity scores aggregated with + the predictions. + + alpha_np: NDArray + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + axis: int + The axis from which to compute the quantile. + + method: str + ``"higher"`` or ``"lower"`` the method to compute the quantile. + + Returns + ------- + NDArray + Lower and upper quantile of the prediction intervals. + These quantiles are identical if the score is not symmetrical. + """ + quantile = np.column_stack([ + np_nanquantile( + values.astype(float), + _alpha, + axis=axis, + method=method + ) + for _alpha in alpha_np + ]) + return quantile + + def get_bounds( + self, + X: ArrayLike, + estimator: RegressorMixin, + conformity_scores: NDArray, + alpha_np: NDArray, + ensemble: bool, + method: str + ) -> Tuple[NDArray, NDArray, NDArray]: + """ + Compute bounds of the prediction intervals from the observed values, + the estimator of MapieRegressor and the conformity scores. + + Parameters + ---------- + X: ArrayLike + Observed feature values. + + estimator: RegressorMixin + Estimator that is fitted to predict y from X. + + conformity_scores: ArrayLike + Conformity scores. + + alpha_np: NDArray + NDArray of floats between ``0`` and ``1``, represents the + uncertainty of the confidence interval. + + ensemble: bool + Boolean determining whether the predictions are ensembled or not. + + method: str + The method parameter of MapieRegressor. + + Returns + ------- + Tuple[NDArray, NDArray, NDArray] + - The predictions itself. (y_pred) + - The lower bounds of the prediction intervals. + - The upper bounds of the prediction intervals. + """ + y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) + + signed = -1 if self.sym else 1 + alpha_low = alpha_np if self.sym else alpha_np / 2 + alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 + + if method == "plus": + bound_low = self._get_quantile( + self.get_estimation_distribution( + X, y_pred_low, signed * conformity_scores + ), + alpha_low, + axis=1, + method="lower" + ) + bound_up = self._get_quantile( + self.get_estimation_distribution( + X, y_pred_up, conformity_scores + ), + alpha_up, + axis=1, + method="higher" + ) + else: + quantile_search = "higher" if self.sym else "lower" + alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 + + quantile_low = self._get_quantile( + conformity_scores, alpha_low, axis=0, method=quantile_search + ) + quantile_up = self._get_quantile( + conformity_scores, alpha_up, axis=0, method="higher" + ) + bound_low = self.get_estimation_distribution( + X, y_pred_low, signed * quantile_low + ) + bound_up = self.get_estimation_distribution( + X, y_pred_up, quantile_up + ) + + return y_pred, bound_low, bound_up diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/residual_conformity_scores.py index 033baa59..6d317d29 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/residual_conformity_scores.py @@ -2,7 +2,6 @@ from mapie._machine_precision import EPSILON from mapie._typing import ArrayLike, NDArray - from mapie.conformity_scores import ConformityScore @@ -24,28 +23,33 @@ def __init__( def get_signed_conformity_scores( self, + X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, ) -> NDArray: """ - Compute the signed conformity scores from the predicted values - and the observed ones, from the following formula: + Compute the signed conformity scores from the observed values + and the estimator, from the following formula: signed conformity score = y - y_pred """ return np.subtract(y, y_pred) def get_estimation_distribution( self, + X: ArrayLike, y_pred: ArrayLike, - conformity_scores: ArrayLike, - ) -> NDArray: + values: ArrayLike + ): """ Compute samples of the estimation distribution from the predicted - values and the conformity scores, from the following formula: + targets and ``values``, from the following formula: signed conformity score = y - y_pred <=> y = y_pred + signed conformity score + + ``values`` can be either the conformity scores or + the conformity scores aggregated with the predictions. """ - return np.add(y_pred, conformity_scores) + return np.add(y_pred, values) class GammaConformityScore(ConformityScore): @@ -89,22 +93,21 @@ def _check_predicted_data( "in conformity with the Gamma distribution support." ) + @staticmethod def _all_strictly_positive( - self, y: ArrayLike, ) -> bool: - if np.any(np.less_equal(y, 0)): - return False - return True + return not np.any(np.less_equal(y, 0)) def get_signed_conformity_scores( self, + X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, ) -> NDArray: """ - Compute samples of the estimation distribution from the predicted - values and the conformity scores, from the following formula: + Compute the signed conformity scores from the observed values + and the estimator, from the following formula: signed conformity score = (y - y_pred) / y_pred """ self._check_observed_data(y) @@ -113,14 +116,18 @@ def get_signed_conformity_scores( def get_estimation_distribution( self, + X: ArrayLike, y_pred: ArrayLike, - conformity_scores: ArrayLike, + values: ArrayLike, ) -> NDArray: """ Compute samples of the estimation distribution from the predicted - values and the conformity scores, from the following formula: + targets and ``values``, from the following formula: signed conformity score = (y - y_pred) / y_pred <=> y = y_pred * (1 + signed conformity score) + + ``values`` can be either the conformity scores or + the conformity scores aggregated with the predictions. """ self._check_predicted_data(y_pred) - return np.multiply(y_pred, np.add(1, conformity_scores)) + return np.multiply(y_pred, np.add(1, values)) diff --git a/mapie/regression/estimator.py b/mapie/regression/estimator.py new file mode 100644 index 00000000..7db6aba3 --- /dev/null +++ b/mapie/regression/estimator.py @@ -0,0 +1,504 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple, Union, cast + +import numpy as np +from joblib import Parallel, delayed +from sklearn.base import RegressorMixin, clone +from sklearn.model_selection import BaseCrossValidator, ShuffleSplit +from sklearn.utils import _safe_indexing +from sklearn.utils.validation import (_num_samples, check_is_fitted) + +from mapie._typing import ArrayLike, NDArray +from mapie.aggregation_functions import aggregate_all, phi2D +from mapie.utils import (check_nan_in_aposteriori_prediction, + fit_estimator) + + +class EnsembleRegressor(RegressorMixin): + """ + This class implements methods to handle the training and usage of the + estimator. This estimator can be unique or composed by cross validated + estimators. + + Parameters + ---------- + estimator: Optional[RegressorMixin] + Any regressor with scikit-learn API + (i.e. with ``fit`` and ``predict`` methods). + If ``None``, estimator defaults to a ``LinearRegression`` instance. + + By default ``None``. + + method: str + Method to choose for prediction interval estimates. + Choose among: + + - ``"naive"``, based on training set conformity scores, + - ``"base"``, based on validation sets conformity scores, + - ``"plus"``, based on validation conformity scores and + testing predictions, + - ``"minmax"``, based on validation conformity scores and + testing predictions (min/max among cross-validation clones). + + By default ``"plus"``. + + cv: Optional[Union[int, str, BaseCrossValidator]] + The cross-validation strategy for computing conformity scores. + It directly drives the distinction between jackknife and cv variants. + Choose among: + + - ``None``, to use the default 5-fold cross-validation + - integer, to specify the number of folds. + If equal to ``-1``, equivalent to + ``sklearn.model_selection.LeaveOneOut()``. + - CV splitter: any ``sklearn.model_selection.BaseCrossValidator`` + Main variants are: + - ``sklearn.model_selection.LeaveOneOut`` (jackknife), + - ``sklearn.model_selection.KFold`` (cross-validation), + - ``subsample.Subsample`` object (bootstrap). + - ``"split"``, does not involve cross-validation but a division + of the data into training and calibration subsets. The splitter + used is the following: ``sklearn.model_selection.ShuffleSplit``. + - ``"prefit"``, assumes that ``estimator`` has been fitted already, + and the ``method`` parameter is ignored. + All data provided in the ``fit`` method is then used + for computing conformity scores only. + At prediction time, quantiles of these conformity scores are used + to provide a prediction interval with fixed width. + The user has to take care manually that data for model fitting and + conformity scores estimate are disjoint. + + By default ``None``. + + test_size: Optional[Union[int, float]] + If ``float``, should be between ``0.0`` and ``1.0`` and represent the + proportion of the dataset to include in the test split. If ``int``, + represents the absolute number of test samples. If ``None``, + it will be set to ``0.1``. + + If cv is not ``"split"``, ``test_size`` is ignored. + + By default ``None``. + + n_jobs: Optional[int] + Number of jobs for parallel processing using joblib + via the "locky" backend. + If ``-1`` all CPUs are used. + If ``1`` is given, no parallel computing code is used at all, + which is useful for debugging. + For ``n_jobs`` below ``-1``, ``(n_cpus + 1 - n_jobs)`` are used. + ``None`` is a marker for `unset` that will be interpreted as + ``n_jobs=1`` (sequential execution). + + By default ``None``. + + agg_function: Optional[str] + Determines how to aggregate predictions from perturbed models, both at + training and prediction time. + + If ``None``, it is ignored except if ``cv`` class is ``Subsample``, + in which case an error is raised. + If ``"mean"`` or ``"median"``, returns the mean or median of the + predictions computed from the out-of-folds models. + Note: if you plan to set the ``ensemble`` argument to ``True`` in the + ``predict`` method, you have to specify an aggregation function. + Otherwise an error would be raised. + + The Jackknife+ interval can be interpreted as an interval around the + median prediction, and is guaranteed to lie inside the interval, + unlike the single estimator predictions. + + When the cross-validation strategy is ``Subsample`` (i.e. for the + Jackknife+-after-Bootstrap method), this function is also used to + aggregate the training set in-sample predictions. + + If ``cv`` is ``"prefit"`` or ``"split"``, ``agg_function`` is ignored. + + By default ``"mean"``. + + verbose: int + The verbosity level, used with joblib for multiprocessing. + The frequency of the messages increases with the verbosity level. + If it more than ``10``, all iterations are reported. + Above ``50``, the output is sent to stdout. + + By default ``0``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state used for random sampling. + Pass an int for reproducible output across multiple function calls. + + By default ``None``. + + Attributes + ---------- + single_estimator_: sklearn.RegressorMixin + Estimator fitted on the whole training set. + + estimators_: list + List of out-of-folds estimators. + + k_: ArrayLike + - Array of nans, of shape (len(y), 1) if ``cv`` is ``"prefit"`` + (defined but not used) + - Dummy array of folds containing each training sample, otherwise. + Of shape (n_samples_train, cv.get_n_splits(X_train, y_train)). + """ + no_agg_cv_ = ["prefit", "split"] + no_agg_methods_ = ["naive", "base"] + fit_attributes = [ + "single_estimator_", + "estimators_", + "k_", + ] + + def __init__( + self, + estimator: Optional[RegressorMixin], + method: str, + cv: Optional[Union[int, str, BaseCrossValidator]], + agg_function: Optional[str], + n_jobs: Optional[int], + random_state: Optional[Union[int, np.random.RandomState]], + test_size: Optional[Union[int, float]], + verbose: int + ): + self.estimator = estimator + self.method = method + self.cv = cv + self.agg_function = agg_function + self.n_jobs = n_jobs + self.random_state = random_state + self.test_size = test_size + self.verbose = verbose + + @staticmethod + def _fit_oof_estimator( + estimator: RegressorMixin, + X: ArrayLike, + y: ArrayLike, + train_index: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + ) -> RegressorMixin: + """ + Fit a single out-of-fold model on a given training set. + + Parameters + ---------- + estimator: RegressorMixin + Estimator to train. + + X: ArrayLike of shape (n_samples, n_features) + Input data. + + y: ArrayLike of shape (n_samples,) + Input labels. + + train_index: ArrayLike of shape (n_samples_train) + Training data indices. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Sample weights. If None, then samples are equally weighted. + By default ``None``. + + Returns + ------- + Tuple[RegressorMixin, NDArray, ArrayLike] + + - [0]: RegressorMixin, fitted estimator + - [1]: NDArray of shape (n_samples_val,), + estimator predictions on the validation fold. + - [2]: ArrayLike of shape (n_samples_val,), + validation data indices. + """ + X_train = _safe_indexing(X, train_index) + y_train = _safe_indexing(y, train_index) + if not (sample_weight is None): + sample_weight = _safe_indexing(sample_weight, train_index) + sample_weight = cast(NDArray, sample_weight) + + estimator = fit_estimator( + estimator, X_train, y_train, sample_weight=sample_weight + ) + return estimator + + def fit( + self, + X: ArrayLike, + y: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + ) -> EnsembleRegressor: + """ + Fit the base estimator under the ``single_estimator_`` attribute. + Fit all cross-validated estimator clones + and rearrange them into a list, the ``estimators_`` attribute. + Out-of-fold conformity scores are stored under + the ``conformity_scores_`` attribute. + + Parameters + ---------- + X: ArrayLike of shape (n_samples, n_features) + Input data. + + y: ArrayLike of shape (n_samples,) + Input labels. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Sample weights. If None, then samples are equally weighted. + By default ``None``. + + Returns + ------- + EnsembleRegressor + The estimator fitted. + + """ + # Initialization + single_estimator_: RegressorMixin + estimators_: List[RegressorMixin] = [] + full_indexes = np.arange(_num_samples(X)) + cv = self.cv + estimator = self.estimator + n_samples = _num_samples(y) + + # Computation + if cv == "prefit": + single_estimator_ = estimator + self.k_ = np.full( + shape=(n_samples, 1), fill_value=np.nan, dtype=float + ) + else: + single_estimator_ = self._fit_oof_estimator( + clone(estimator), X, y, full_indexes, sample_weight + ) + cv = cast(BaseCrossValidator, cv) + self.k_ = np.full( + shape=(n_samples, cv.get_n_splits(X, y)), + fill_value=np.nan, + dtype=float, + ) + if self.method == "naive": + estimators_ = [single_estimator_] + else: + estimators_ = Parallel(self.n_jobs, verbose=self.verbose)( + delayed(self._fit_oof_estimator)( + clone(estimator), X, y, train_index, sample_weight + ) + for train_index, _ in cv.split(X) + ) + if isinstance(cv, ShuffleSplit): + single_estimator_ = estimators_[0] + + self.single_estimator_ = single_estimator_ + self.estimators_ = estimators_ + + return self + + def _aggregate_with_mask( + self, + x: NDArray, + k: NDArray + ) -> NDArray: + """ + Take the array of predictions, made by the refitted estimators, + on the testing set, and the 1-or-nan array indicating for each training + sample which one to integrate, and aggregate to produce phi-{t}(x_t) + for each training sample x_t. + + Parameters: + ----------- + x: ArrayLike of shape (n_samples_test, n_estimators) + Array of predictions, made by the refitted estimators, + for each sample of the testing set. + + k: ArrayLike of shape (n_samples_training, n_estimators) + 1-or-nan array: indicates whether to integrate the prediction + of a given estimator into the aggregation, for each training + sample. + + Returns: + -------- + ArrayLike of shape (n_samples_test,) + Array of aggregated predictions for each testing sample. + """ + if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_: + raise ValueError( + "There should not be aggregation of predictions " + f"if cv is in '{self.no_agg_cv_}' " + f"or if method is in '{self.no_agg_methods_}'." + ) + elif self.agg_function == "median": + return phi2D(A=x, B=k, fun=lambda x: np.nanmedian(x, axis=1)) + # To aggregate with mean() the aggregation coud be done + # with phi2D(A=x, B=k, fun=lambda x: np.nanmean(x, axis=1). + # However, phi2D contains a np.apply_along_axis loop which + # is much slower than the matrices multiplication that can + # be used to compute the means. + elif self.agg_function in ["mean", None]: + K = np.nan_to_num(k, nan=0.0) + return np.matmul(x, (K / (K.sum(axis=1, keepdims=True))).T) + else: + raise ValueError("The value of self.agg_function is not correct") + + @staticmethod + def _predict_oof_estimator( + estimator: RegressorMixin, + X: ArrayLike, + val_index: ArrayLike, + ): + """ + Perform predictions on a single out-of-fold model on a validation set. + + Parameters + ---------- + estimator: RegressorMixin + Estimator to train. + + X: ArrayLike of shape (n_samples, n_features) + Input data. + + val_index: ArrayLike of shape (n_samples_val) + Validation data indices. + + Returns + ------- + Tuple[NDArray, ArrayLike] + Predictions of estimator from val_index of X. + """ + X_val = _safe_indexing(X, val_index) + if _num_samples(X_val) > 0: + y_pred = estimator.predict(X_val) + else: + y_pred = np.array([]) + return y_pred, val_index + + def _pred_multi(self, X: ArrayLike) -> NDArray: + """ + Return a prediction per train sample for each test sample, by + aggregation with matrix ``k_``. + + Parameters + ---------- + X: ArrayLike of shape (n_samples_test, n_features) + Input data + + Returns + ------- + NDArray of shape (n_samples_test, n_samples_train) + """ + y_pred_multi = np.column_stack( + [e.predict(X) for e in self.estimators_] + ) + # At this point, y_pred_multi is of shape + # (n_samples_test, n_estimators_). The method + # ``_aggregate_with_mask`` fits it to the right size + # thanks to the shape of k_. + y_pred_multi = self._aggregate_with_mask(y_pred_multi, self.k_) + return y_pred_multi + + def predict_calib(self, X: ArrayLike) -> NDArray: + """ + Perform predictions on X : the calibration set. This method is + called in the ConformityScore class to compute the conformity scores. + + Parameters + ---------- + X: ArrayLike of shape (n_samples_test, n_features) + Input data + + Returns + ------- + NDArray of shape (n_samples_test, 1) + The predictions. + """ + check_is_fitted(self, self.fit_attributes) + + if self.cv == "prefit": + y_pred = self.single_estimator_.predict(X) + else: + if self.method == "naive": + y_pred = self.single_estimator_.predict(X) + else: + cv = cast(BaseCrossValidator, self.cv) + outputs = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( + delayed(self._predict_oof_estimator)( + estimator, X, calib_index, + ) + for (_, calib_index), estimator in zip(cv.split(X), + self.estimators_) + ) + predictions, indices = map( + list, zip(*outputs) + ) + n_samples = _num_samples(X) + pred_matrix = np.full( + shape=(n_samples, cv.get_n_splits(X)), + fill_value=np.nan, + dtype=float, + ) + for i, ind in enumerate(indices): + pred_matrix[ind, i] = np.array( + predictions[i], dtype=float + ) + self.k_[ind, i] = 1 + check_nan_in_aposteriori_prediction(pred_matrix) + + y_pred = aggregate_all(self.agg_function, pred_matrix) + + return y_pred + + def predict( + self, + X: ArrayLike, + ensemble: bool = False + ) -> Union[NDArray, Tuple[NDArray, NDArray, NDArray]]: + """ + Predict target from X. It also computes the prediction per train sample + for each test sample according to ``self.method``. + + Parameters + ---------- + X: ArrayLike of shape (n_samples, n_features) + Test data. + + ensemble: bool + Boolean determining whether the predictions are ensembled or not. + If ``False``, predictions are those of the model trained on the + whole training set. + If ``True``, predictions from perturbed models are aggregated by + the aggregation function specified in the ``agg_function`` + attribute. + + If ``cv`` is ``"prefit"`` or ``"split"``, ``ensemble`` is ignored. + + By default ``False``. + + Returns + ------- + Tuple[NDArray, NDArray, NDArray] + - Predictions + - The multiple predictions for the lower bound of the intervals. + - The multiple predictions for the upper bound of the intervals. + """ + + check_is_fitted(self, self.fit_attributes) + + y_pred = self.single_estimator_.predict(X) + + if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_: + y_pred_multi_low = y_pred[:, np.newaxis] + y_pred_multi_up = y_pred[:, np.newaxis] + else: + y_pred_multi = self._pred_multi(X) + + if self.method == "minmax": + y_pred_multi_low = np.min(y_pred_multi, axis=1, keepdims=True) + y_pred_multi_up = np.max(y_pred_multi, axis=1, keepdims=True) + else: + y_pred_multi_low = y_pred_multi + y_pred_multi_up = y_pred_multi + + if ensemble: + y_pred = aggregate_all(self.agg_function, y_pred_multi) + return y_pred, y_pred_multi_low, y_pred_multi_up diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index a7f17316..5e007255 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -1,26 +1,23 @@ from __future__ import annotations -from typing import Iterable, List, Optional, Tuple, Union, cast +from typing import Iterable, Optional, Tuple, Union, cast import numpy as np -from joblib import Parallel, delayed -from sklearn.base import BaseEstimator, RegressorMixin, clone +from sklearn.base import BaseEstimator, RegressorMixin from sklearn.linear_model import LinearRegression -from sklearn.model_selection import BaseCrossValidator, ShuffleSplit +from sklearn.model_selection import BaseCrossValidator from sklearn.pipeline import Pipeline -from sklearn.utils import _safe_indexing, check_random_state -from sklearn.utils.validation import (_check_y, _num_samples, check_is_fitted, +from sklearn.utils import check_random_state +from sklearn.utils.validation import (_check_y, check_is_fitted, indexable) -from mapie._compatibility import np_nanquantile from mapie._typing import ArrayLike, NDArray -from mapie.aggregation_functions import aggregate_all, phi2D from mapie.conformity_scores import ConformityScore +from .estimator import EnsembleRegressor from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_conformity_score, check_cv, check_estimator_fit_predict, check_n_features_in, - check_n_jobs, check_nan_in_aposteriori_prediction, - check_null_weight, check_verbose, fit_estimator) + check_n_jobs, check_null_weight, check_verbose) class MapieRegressor(BaseEstimator, RegressorMixin): @@ -162,21 +159,12 @@ class MapieRegressor(BaseEstimator, RegressorMixin): valid_methods_: List[str] List of all valid methods. - single_estimator_: sklearn.RegressorMixin - Estimator fitted on the whole training set. - - estimators_: list - List of out-of-folds estimators. + estimator_: EnsembleRegressor + Sklearn estimator that handle all that is related to the estimator. conformity_scores_: ArrayLike of shape (n_samples_train,) Conformity scores between ``y_train`` and ``y_pred``. - k_: ArrayLike - - Array of nans, of shape (len(y), 1) if ``cv`` is ``"prefit"`` - (defined but not used) - - Dummy array of folds containing each training sample, otherwise. - Of shape (n_samples_train, cv.get_n_splits(X_train, y_train)). - n_features_in_: int Number of features passed to the ``fit`` method. @@ -220,9 +208,7 @@ class MapieRegressor(BaseEstimator, RegressorMixin): valid_agg_functions_ = [None, "median", "mean"] ensemble_agg_functions_ = ["median", "mean"] fit_attributes = [ - "single_estimator_", - "estimators_", - "k_", + "estimator_", "conformity_scores_", "conformity_score_function_", "n_features_in_", @@ -396,139 +382,37 @@ def _check_ensemble( f"in '{self.ensemble_agg_functions_}'." ) - def _fit_and_predict_oof_model( + def _check_fit_parameters( self, - estimator: RegressorMixin, X: ArrayLike, y: ArrayLike, - train_index: ArrayLike, - val_index: ArrayLike, sample_weight: Optional[ArrayLike] = None, - ) -> Tuple[RegressorMixin, NDArray, ArrayLike]: - """ - Fit a single out-of-fold model on a given training set and - perform predictions on a test set. - - Parameters - ---------- - estimator: RegressorMixin - Estimator to train. - - X: ArrayLike of shape (n_samples, n_features) - Input data. - - y: ArrayLike of shape (n_samples,) - Input labels. - - train_index: ArrayLike of shape (n_samples_train) - Training data indices. - - val_index: ArrayLike of shape (n_samples_val) - Validation data indices. - - sample_weight: Optional[ArrayLike] of shape (n_samples,) - Sample weights. If ``None``, then samples are equally weighted. - By default ``None``. - - Returns - ------- - Tuple[RegressorMixin, NDArray, ArrayLike] - - - [0]: RegressorMixin, fitted estimator - - [1]: NDArray of shape (n_samples_val,), - estimator predictions on the validation fold. - - [2]: ArrayLike of shape (n_samples_val,), - validation data indices. - """ - X_train = _safe_indexing(X, train_index) - y_train = _safe_indexing(y, train_index) - X_val = _safe_indexing(X, val_index) - if sample_weight is None: - estimator = fit_estimator(estimator, X_train, y_train) - else: - sample_weight_train = _safe_indexing(sample_weight, train_index) - estimator = fit_estimator( - estimator, X_train, y_train, sample_weight_train - ) - if _num_samples(X_val) > 0: - y_pred = estimator.predict(X_val) - else: - y_pred = np.array([]) - return estimator, y_pred, val_index - - def _aggregate_with_mask( - self, - x: NDArray, - k: NDArray - ) -> NDArray: - """ - Take the array of predictions, made by the refitted estimators, - on the testing set, and the 1-or-nan array indicating for each training - sample which one to integrate, and aggregate to produce phi-{t}(x_t) - for each training sample x_t. - - Parameters: - ----------- - x: ArrayLike of shape (n_samples_test, n_estimators) - Array of predictions, made by the refitted estimators, - for each sample of the testing set. - - k: ArrayLike of shape (n_samples_training, n_estimators) - 1-or-nan array: indicates whether to integrate the prediction - of a given estimator into the aggregation, for each training - sample. - - Returns: - -------- - ArrayLike of shape (n_samples_test,) - Array of aggregated predictions for each testing sample. - """ - if self.method in self.no_agg_methods_ \ - or self.cv in self.no_agg_cv_: - raise ValueError( - "There should not be aggregation of predictions " - f"if cv is in '{self.no_agg_cv_}' " - f"or if method is in '{self.no_agg_methods_}'." - ) - elif self.agg_function == "median": - return phi2D(A=x, B=k, fun=lambda x: np.nanmedian(x, axis=1)) - # To aggregate with mean() the aggregation coud be done - # with phi2D(A=x, B=k, fun=lambda x: np.nanmean(x, axis=1). - # However, phi2D contains a np.apply_along_axis loop which - # is much slower than the matrices multiplication that can - # be used to compute the means. - elif self.agg_function in ["mean", None]: - K = np.nan_to_num(k, nan=0.0) - return np.matmul(x, (K / (K.sum(axis=1, keepdims=True))).T) - else: - raise ValueError("The value of self.agg_function is not correct") - - def _pred_multi( - self, - X: ArrayLike - ) -> NDArray: - """ - Return a prediction per train sample for each test sample, by - aggregation with matrix ``k_``. - - Parameters - ---------- - X: NDArray of shape (n_samples_test, n_features) - Input data - - Returns - ------- - NDArray of shape (n_samples_test, n_samples_train) - """ - y_pred_multi = np.column_stack( - [e.predict(X) for e in self.estimators_] + ): + # Checking + self._check_parameters() + cv = check_cv( + self.cv, test_size=self.test_size, random_state=self.random_state + ) + estimator = self._check_estimator(self.estimator) + agg_function = self._check_agg_function(self.agg_function) + cs_estimator = check_conformity_score( + self.conformity_score ) - # At this point, y_pred_multi is of shape - # (n_samples_test, n_estimators_). The method - # ``_aggregate_with_mask`` fits it to the right size - # thanks to the shape of k_. - y_pred_multi = self._aggregate_with_mask(y_pred_multi, self.k_) - return y_pred_multi + X, y = indexable(X, y) + y = _check_y(y) + sample_weight, X, y = check_null_weight(sample_weight, X, y) + self.n_features_in_ = check_n_features_in(X) + + # Casting + cv = cast(BaseCrossValidator, cv) + estimator = cast(RegressorMixin, estimator) + cs_estimator = cast(ConformityScore, cs_estimator) + agg_function = cast(Optional[str], agg_function) + X = cast(NDArray, X) + y = cast(NDArray, y) + sample_weight = cast(Optional[NDArray], sample_weight) + + return estimator, cs_estimator, agg_function, cv, X, y, sample_weight def fit( self, @@ -539,11 +423,9 @@ def fit( """ Fit estimator and compute conformity scores used for prediction intervals. - Fit the base estimator under the ``single_estimator_`` attribute. - Fit all cross-validated estimator clones - and rearrange them into a list, the ``estimators_`` attribute. - Out-of-fold conformity scores are stored under - the ``conformity_scores_`` attribute. + + All the types of estimator (single or cross validated ones) are + encapsulated under EnsembleRegressor. Parameters ---------- @@ -570,84 +452,34 @@ def fit( The model itself. """ # Checks - self._check_parameters() - cv = check_cv( - self.cv, test_size=self.test_size, random_state=self.random_state - ) - estimator = self._check_estimator(self.estimator) - agg_function = self._check_agg_function(self.agg_function) - X, y = indexable(X, y) - y = _check_y(y) - sample_weight = cast(Optional[NDArray], sample_weight) - self.n_features_in_ = check_n_features_in(X, cv, estimator) - sample_weight, X, y = check_null_weight(sample_weight, X, y) - self.conformity_score_function_ = check_conformity_score( - self.conformity_score + (estimator, + self.conformity_score_function_, + agg_function, + cv, + X, + y, + sample_weight) = self._check_fit_parameters(X, y, sample_weight) + + self.estimator_ = EnsembleRegressor( + estimator, + self.method, + cv, + agg_function, + self.n_jobs, + self.random_state, + self.test_size, + self.verbose ) - y = cast(NDArray, y) - n_samples = _num_samples(y) - - # Initialization - self.estimators_: List[RegressorMixin] = [] - - # Work - if cv == "prefit": - self.single_estimator_ = estimator - y_pred = self.single_estimator_.predict(X) - self.k_ = np.full( - shape=(n_samples, 1), fill_value=np.nan, dtype=float - ) - else: - cv = cast(BaseCrossValidator, cv) - self.k_ = np.full( - shape=(n_samples, cv.get_n_splits(X, y)), - fill_value=np.nan, - dtype=float, - ) - - self.single_estimator_ = fit_estimator( - clone(estimator), X, y, sample_weight + # Fit the prediction function + self.estimator_ = self.estimator_.fit(X, y, sample_weight) + y_pred = self.estimator_.predict_calib(X) + + # Compute the conformity scores (manage jk-ab case) + self.conformity_scores_ = \ + self.conformity_score_function_.get_conformity_scores( + X, y, y_pred ) - if self.method == "naive": - y_pred = self.single_estimator_.predict(X) - else: - outputs = Parallel(n_jobs=self.n_jobs, verbose=self.verbose)( - delayed(self._fit_and_predict_oof_model)( - clone(estimator), - X, - y, - train_index, - val_index, - sample_weight, - ) - for train_index, val_index in cv.split(X) - ) - self.estimators_, predictions, val_indices = map( - list, zip(*outputs) - ) - - pred_matrix = np.full( - shape=(n_samples, cv.get_n_splits(X, y)), - fill_value=np.nan, - dtype=float, - ) - for i, val_ind in enumerate(val_indices): - pred_matrix[val_ind, i] = np.array( - predictions[i], dtype=float - ) - self.k_[val_ind, i] = 1 - check_nan_in_aposteriori_prediction(pred_matrix) - - y_pred = aggregate_all(agg_function, pred_matrix) - - self.conformity_scores_ = ( - self.conformity_score_function_.get_conformity_scores(y, y_pred) - ) - - if isinstance(cv, ShuffleSplit): - self.single_estimator_ = self.estimators_[0] - return self def predict( @@ -708,74 +540,22 @@ def predict( self._check_ensemble(ensemble) alpha = cast(Optional[NDArray], check_alpha(alpha)) - y_pred = self.single_estimator_.predict(X) - n = len(self.conformity_scores_) - - if alpha is None: - return np.array(y_pred) - - alpha_np = cast(NDArray, alpha) - check_alpha_and_n_samples(alpha_np, n) - - if self.method in self.no_agg_methods_ \ - or self.cv in self.no_agg_cv_: - y_pred_multi_low = y_pred[:, np.newaxis] - y_pred_multi_up = y_pred[:, np.newaxis] - else: - y_pred_multi = self._pred_multi(X) - - if self.method == "minmax": - y_pred_multi_low = np.min(y_pred_multi, axis=1, keepdims=True) - y_pred_multi_up = np.max(y_pred_multi, axis=1, keepdims=True) - else: - y_pred_multi_low = y_pred_multi - y_pred_multi_up = y_pred_multi - - if ensemble: - y_pred = aggregate_all(self.agg_function, y_pred_multi) - - # compute distributions of lower and upper bounds - if self.conformity_score_function_.sym: - conformity_scores_low = -self.conformity_scores_ - conformity_scores_up = self.conformity_scores_ - else: - conformity_scores_low = self.conformity_scores_ - conformity_scores_up = self.conformity_scores_ - alpha_np = alpha_np / 2 - - lower_bounds = ( - self.conformity_score_function_.get_estimation_distribution( - y_pred_multi_low, conformity_scores_low - ) - ) - upper_bounds = ( - self.conformity_score_function_.get_estimation_distribution( - y_pred_multi_up, conformity_scores_up - ) - ) - - # get desired confidence intervals according to alpha - y_pred_low = np.column_stack( - [ - np_nanquantile( - lower_bounds.astype(float), - _alpha, - axis=1, - method="lower", - ) - for _alpha in alpha_np - ] - ) - y_pred_up = np.column_stack( - [ - np_nanquantile( - upper_bounds.astype(float), - 1 - _alpha, - axis=1, - method="higher", + if not (alpha is None): + n = len(self.conformity_scores_) + alpha_np = cast(NDArray, alpha) + check_alpha_and_n_samples(alpha_np, n) + + y_pred, bound_low, bound_up = \ + self.conformity_score_function_.get_bounds( + X, + self.estimator_, + self.conformity_scores_, + alpha_np, + ensemble, + self.method ) - for _alpha in alpha_np - ] - ) + return y_pred, np.stack([bound_low, bound_up], axis=1) - return y_pred, np.stack([y_pred_low, y_pred_up], axis=1) + else: + y_pred, _, _ = self.estimator_.predict(X, ensemble) + return y_pred From 3f42406e6283fe7b56151ec929d53ad48dd70895 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 09:59:35 +0200 Subject: [PATCH 02/30] ENH: tests for refacto of MapieRegressor and ConformityScore --- mapie/regression/__init__.py | 4 +- mapie/regression/time_series_regression.py | 4 +- mapie/tests/test_common.py | 7 +- mapie/tests/test_conformity_scores.py | 103 ++++++++++++++------- mapie/tests/test_regression.py | 32 ++++--- mapie/tests/test_time_series_regression.py | 7 +- 6 files changed, 106 insertions(+), 51 deletions(-) diff --git a/mapie/regression/__init__.py b/mapie/regression/__init__.py index 16243ace..7360f618 100644 --- a/mapie/regression/__init__.py +++ b/mapie/regression/__init__.py @@ -1,9 +1,11 @@ from .regression import MapieRegressor from .quantile_regression import MapieQuantileRegressor from .time_series_regression import MapieTimeSeriesRegressor +from .estimator import EnsembleRegressor __all__ = [ "MapieRegressor", "MapieQuantileRegressor", - "MapieTimeSeriesRegressor" + "MapieTimeSeriesRegressor", + "EnsembleRegressor" ] diff --git a/mapie/regression/time_series_regression.py b/mapie/regression/time_series_regression.py index 080eb3e4..9e7354b9 100644 --- a/mapie/regression/time_series_regression.py +++ b/mapie/regression/time_series_regression.py @@ -283,7 +283,7 @@ def predict( check_is_fitted(self, self.fit_attributes) self._check_ensemble(ensemble) alpha = cast(Optional[NDArray], check_alpha(alpha)) - y_pred = self.single_estimator_.predict(X) + y_pred = self.estimator_.single_estimator_.predict(X) n = len(self.conformity_scores_) if alpha is None: @@ -321,7 +321,7 @@ def predict( y_pred_low = y_pred[:, np.newaxis] + lower_quantiles y_pred_up = y_pred[:, np.newaxis] + higher_quantiles else: - y_pred_multi = self._pred_multi(X) + y_pred_multi = self.estimator_._pred_multi(X) pred = aggregate_all(self.agg_function, y_pred_multi) lower_bounds, upper_bounds = pred, pred diff --git a/mapie/tests/test_common.py b/mapie/tests/test_common.py index 0f199982..45379bc2 100644 --- a/mapie/tests/test_common.py +++ b/mapie/tests/test_common.py @@ -108,7 +108,12 @@ def test_none_estimator(pack: Tuple[BaseEstimator, BaseEstimator]) -> None: MapieEstimator, DefaultEstimator = pack mapie_estimator = MapieEstimator(estimator=None) mapie_estimator.fit(X_toy, y_toy) - assert isinstance(mapie_estimator.single_estimator_, DefaultEstimator) + if isinstance(mapie_estimator, MapieClassifier): + assert isinstance(mapie_estimator.single_estimator_, DefaultEstimator) + if isinstance(mapie_estimator, MapieRegressor): + assert isinstance( + mapie_estimator.estimator_.single_estimator_, DefaultEstimator + ) @pytest.mark.parametrize("estimator", [0, "a", KFold(), ["a", "b"]]) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 7441302f..ce27e227 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -1,15 +1,22 @@ import numpy as np import pytest -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray, ArrayLike from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) +from mapie.regression import EnsembleRegressor +from sklearn.linear_model import LinearRegression +from sklearn.model_selection import KFold -X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) -y_toy = np.array([5, 7, 9, 11, 13, 15]) -y_pred_list = [4, 7, 10, 12, 13, 12] -conf_scores_list = [1, 0, -1, -1, 0, 3] -conf_scores_gamma_list = [1 / 4, 0, -1 / 10, -1 / 12, 0, 3 / 12] +X_toy_train = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) +y_toy_train = np.array([5, 7, 9, 11, 13, 15]) +X_toy_test = np.array([6, 9, 10, 2, 4, 5]).reshape(-1, 1) +y_toy_test = np.array([15, 4, 90, 1, 15, 1]) +y_pred_list = [17., 23., 25., 9., 13., 15.] +conf_scores_list = [-2., -19., 65., -8., 2., -14.] +conf_scores_gamma_list = [-0.11764706, -0.82608696, 2.6, + -0.88888889, 0.15384615, -0.93333333] +random_state = 42 class DummyConformityScore(ConformityScore): @@ -17,19 +24,32 @@ def __init__(self) -> None: super().__init__(sym=True, consistency_check=True) def get_signed_conformity_scores( - self, y: ArrayLike, y_pred: ArrayLike, + self, X: ArrayLike, y: ArrayLike, y_pred: ArrayLike, ) -> NDArray: return np.subtract(y, y_pred) def get_estimation_distribution( - self, y_pred: ArrayLike, conformity_scores: ArrayLike + self, X: ArrayLike, y_pred: ArrayLike, values: ArrayLike ) -> NDArray: """ A positive constant is added to the sum between predictions and conformity scores to make the estimated distribution inconsistent with the conformity score. """ - return np.add(y_pred, conformity_scores) + 1 + return np.add(y_pred, values) + 1 + + +estimator_toy = EnsembleRegressor( + LinearRegression(), + "plus", + KFold(n_splits=5, random_state=None, shuffle=True), + "mean", + None, + random_state, + 0.20, + False +) +estimator_toy_fitted = estimator_toy.fit(X_toy_train, y_toy_train) @pytest.mark.parametrize("sym", [False, True]) @@ -45,10 +65,12 @@ def test_absolute_conformity_score_get_conformity_scores( """Test conformity score computation for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - y_toy, y_pred + X_toy_test, y_toy_test, y_pred + ) + conf_scores = abs_conf_score.get_conformity_scores( + X_toy_test, y_toy_test, y_pred ) - conf_scores = abs_conf_score.get_conformity_scores(y_toy, y_pred) - expected_signed_conf_scores = np.array([1, 0, -1, -1, 0, 3]) + expected_signed_conf_scores = np.array(conf_scores_list) expected_conf_scores = np.abs(expected_signed_conf_scores) np.testing.assert_allclose(signed_conf_scores, expected_signed_conf_scores) np.testing.assert_allclose(conf_scores, expected_conf_scores) @@ -63,8 +85,10 @@ def test_absolute_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for AbsoluteConformityScore.""" # noqa: E501 abs_conf_score = AbsoluteConformityScore() - y_obs = abs_conf_score.get_estimation_distribution(y_pred, conf_scores) - np.testing.assert_allclose(y_obs, y_toy) + y_obs = abs_conf_score.get_estimation_distribution( + X_toy_test, y_pred, conf_scores + ) + np.testing.assert_allclose(y_obs, y_toy_test) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -72,12 +96,12 @@ def test_absolute_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - y_toy, y_pred + X_toy_test, y_toy_test, y_pred ) y_obs = abs_conf_score.get_estimation_distribution( - y_pred, signed_conf_scores + X_toy_test, y_pred, signed_conf_scores ) - np.testing.assert_allclose(y_obs, y_toy) + np.testing.assert_allclose(y_obs, y_toy_test) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -86,7 +110,8 @@ def test_gamma_conformity_score_get_conformity_scores( ) -> None: """Test conformity score computation for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() - conf_scores = gamma_conf_score.get_conformity_scores(y_toy, y_pred) + conf_scores = gamma_conf_score.get_conformity_scores( + X_toy_test, y_toy_test, y_pred) expected_signed_conf_scores = np.array(conf_scores_gamma_list) np.testing.assert_allclose(conf_scores, expected_signed_conf_scores) @@ -104,8 +129,10 @@ def test_gamma_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for GammaConformityScore.""" # noqa: E501 gamma_conf_score = GammaConformityScore() - y_obs = gamma_conf_score.get_estimation_distribution(y_pred, conf_scores) - np.testing.assert_allclose(y_obs, y_toy) + y_obs = gamma_conf_score.get_estimation_distribution( + X_toy_test, y_pred, conf_scores + ) + np.testing.assert_allclose(y_obs, y_toy_test) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -113,12 +140,12 @@ def test_gamma_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() signed_conf_scores = gamma_conf_score.get_signed_conformity_scores( - y_toy, y_pred + X_toy_test, y_toy_test, y_pred ) y_obs = gamma_conf_score.get_estimation_distribution( - y_pred, signed_conf_scores + X_toy_test, y_pred, signed_conf_scores ) - np.testing.assert_allclose(y_obs, y_toy) + np.testing.assert_allclose(y_obs, y_toy_test) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -136,7 +163,9 @@ def test_gamma_conformity_score_check_oberved_value( """Test methods consistency for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() with pytest.raises(ValueError): - gamma_conf_score.get_signed_conformity_scores(y_toy, y_pred) + gamma_conf_score.get_signed_conformity_scores( + [], y_toy, y_pred + ) @pytest.mark.parametrize( @@ -147,6 +176,14 @@ def test_gamma_conformity_score_check_oberved_value( [1, -7, 10, 12, 13, 12], ], ) +@pytest.mark.parametrize( + "X_toy", + [ + np.array([0, -7, 10, 12, 0, 12]).reshape(-1, 1), + np.array([0, 7, -10, 12, 1, -12]).reshape(-1, 1), + np.array([12, -7, 0, 12, 13, 2]).reshape(-1, 1), + ], +) @pytest.mark.parametrize( "conf_scores", [ @@ -155,7 +192,7 @@ def test_gamma_conformity_score_check_oberved_value( ], ) def test_gamma_conformity_score_check_predicted_value( - y_pred: NDArray, conf_scores: NDArray + y_pred: NDArray, conf_scores: NDArray, X_toy: NDArray ) -> None: """Test methods consistency for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() @@ -163,27 +200,31 @@ def test_gamma_conformity_score_check_predicted_value( ValueError, match=r".*At least one of the predicted target is negative.*" ): - gamma_conf_score.get_signed_conformity_scores(y_toy, y_pred) + gamma_conf_score.get_signed_conformity_scores( + X_toy, y_toy_test, y_pred + ) with pytest.raises( ValueError, match=r".*At least one of the predicted target is negative.*" ): - gamma_conf_score.get_estimation_distribution(y_pred, conf_scores) + gamma_conf_score.get_estimation_distribution( + X_toy_test, y_pred, conf_scores + ) def test_check_consistency() -> None: """ - Test that a dummy ConformityScore class that gives inconsistent conformity - scores and distributions raises an error. + Test that a dummy ConformityScore class that gives inconsistent + conformityscores and distributions raises an error. """ dummy_conf_score = DummyConformityScore() conformity_scores = dummy_conf_score.get_signed_conformity_scores( - y_toy, y_pred_list + X_toy_test, y_toy_test, y_pred_list ) with pytest.raises( ValueError, match=r".*The two functions get_conformity_scores.*" ): dummy_conf_score.check_consistency( - y_toy, y_pred_list, conformity_scores + X_toy_test, y_toy_test, y_pred_list, conformity_scores ) diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index c089d0e2..810d11b2 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -18,12 +18,12 @@ from sklearn.utils.validation import check_is_fitted from typing_extensions import TypedDict -from mapie._typing import ArrayLike, NDArray +from mapie._typing import NDArray from mapie.aggregation_functions import aggregate_all from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) from mapie.metrics import regression_coverage_score -from mapie.regression import MapieRegressor +from mapie.regression import MapieRegressor, EnsembleRegressor from mapie.subsample import Subsample X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) @@ -173,8 +173,8 @@ def test_valid_estimator(strategy: str) -> None: estimator=DummyRegressor(), **STRATEGIES[strategy] ) mapie_reg.fit(X_toy, y_toy) - assert isinstance(mapie_reg.single_estimator_, DummyRegressor) - for estimator in mapie_reg.estimators_: + assert isinstance(mapie_reg.estimator_.single_estimator_, DummyRegressor) + for estimator in mapie_reg.estimator_.estimators_: assert isinstance(estimator, DummyRegressor) @@ -502,30 +502,38 @@ def test_aggregate_with_mask_with_prefit() -> None: """ Test ``_aggregate_with_mask`` in case ``cv`` is ``"prefit"``. """ - mapie_reg = MapieRegressor(cv="prefit") + mapie_reg = MapieRegressor(LinearRegression().fit(X, y), cv="prefit") + mapie_reg = mapie_reg.fit(X, y) with pytest.raises( ValueError, match=r".*There should not be aggregation of predictions if cv is*", ): - mapie_reg._aggregate_with_mask(k, k) + mapie_reg.estimator_._aggregate_with_mask(k, k) - mapie_reg = MapieRegressor(agg_function="nonsense") + +def test_aggregate_with_mask_with_invalid_agg_function() -> None: + """Test ``_aggregate_with_mask`` in case ``agg_function`` is invalid.""" + ens_reg = EnsembleRegressor(LinearRegression(), "plus", + KFold( + n_splits=5, random_state=None, shuffle=True + ), + "nonsense", None, random_state, 0.20, False + ) with pytest.raises( ValueError, match=r".*The value of self.agg_function is not correct*", ): - mapie_reg._aggregate_with_mask(k, k) + ens_reg._aggregate_with_mask(k, k) def test_pred_loof_isnan() -> None: """Test that if validation set is empty then prediction is empty.""" mapie_reg = MapieRegressor() - y_pred: ArrayLike - _, y_pred, _ = mapie_reg._fit_and_predict_oof_model( + mapie_reg = mapie_reg.fit(X, y) + y_pred: NDArray + y_pred, _ = mapie_reg.estimator_._predict_oof_estimator( estimator=LinearRegression(), X=X_toy, - y=y_toy, - train_index=[0, 1, 2, 3, 4], val_index=[], ) assert len(y_pred) == 0 diff --git a/mapie/tests/test_time_series_regression.py b/mapie/tests/test_time_series_regression.py index 1bdd5bee..110b423a 100644 --- a/mapie/tests/test_time_series_regression.py +++ b/mapie/tests/test_time_series_regression.py @@ -317,11 +317,10 @@ def test_invalid_aggregate_all() -> None: def test_pred_loof_isnan() -> None: """Test that if validation set is empty then prediction is empty.""" mapie_ts_reg = MapieTimeSeriesRegressor() - _, y_pred, _ = mapie_ts_reg._fit_and_predict_oof_model( - estimator=mapie_ts_reg, + mapie_ts_reg.fit(X_toy, y_toy) + y_pred, _ = mapie_ts_reg.estimator_._predict_oof_estimator( + estimator=mapie_ts_reg.estimator_.estimators_[0], X=X_toy, - y=y_toy, - train_index=[0, 1, 2, 3, 4], val_index=[], ) assert len(y_pred) == 0 From bff806710c3869e7b6e58f2295eeba0f5b23ce63 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 11:14:00 +0200 Subject: [PATCH 03/30] UPD: error in docstrings --- mapie/conformity_scores/conformity_scores.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 87976c76..f5b6877e 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -92,7 +92,7 @@ def get_estimation_distribution( values: ArrayLike ) -> NDArray: """ - Placeholder for ``get_signed_conformity_scores``. + Placeholder for ``get_estimation_distribution``. Subclasses should implement this method! Compute samples of the estimation distribution from the predicted @@ -105,8 +105,8 @@ def get_estimation_distribution( Observed feature values. y_pred: ArrayLike - Predicted values, it can be any type of predictions - (multi, low, up, ...). + Predicted reference values of shape (n_samples, ...). + The last dimension is the reference of the prediction. values: ArrayLike Either the conformity scores or the conformity scores aggregated @@ -114,7 +114,8 @@ def get_estimation_distribution( Returns ------- - ArrayLike + NDArray + Observed values. """ def check_consistency( From df65e8dd785f7903b971d876e6b4e747d32bb682 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 11:15:35 +0200 Subject: [PATCH 04/30] UPD: change method name --- mapie/conformity_scores/conformity_scores.py | 34 +++++++------------- 1 file changed, 11 insertions(+), 23 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index f5b6877e..ffe29341 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -206,7 +206,7 @@ def get_conformity_scores( return conformity_scores @staticmethod - def _get_quantile( + def get_quantile( values: NDArray, alpha_np: NDArray, axis: int, @@ -298,32 +298,20 @@ def get_bounds( alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 if method == "plus": - bound_low = self._get_quantile( - self.get_estimation_distribution( - X, y_pred_low, signed * conformity_scores - ), - alpha_low, - axis=1, - method="lower" - ) - bound_up = self._get_quantile( - self.get_estimation_distribution( - X, y_pred_up, conformity_scores - ), - alpha_up, - axis=1, - method="higher" - ) + bound_low = self.get_quantile(self.get_estimation_distribution( + X, y_pred_low, signed * conformity_scores + ), alpha_low, axis=1, method="lower") + bound_up = self.get_quantile(self.get_estimation_distribution( + X, y_pred_up, conformity_scores + ), alpha_up, axis=1, method="higher") else: quantile_search = "higher" if self.sym else "lower" alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 - quantile_low = self._get_quantile( - conformity_scores, alpha_low, axis=0, method=quantile_search - ) - quantile_up = self._get_quantile( - conformity_scores, alpha_up, axis=0, method="higher" - ) + quantile_low = self.get_quantile(conformity_scores, alpha_low, + axis=0, method=quantile_search) + quantile_up = self.get_quantile(conformity_scores, alpha_up, + axis=0, method="higher") bound_low = self.get_estimation_distribution( X, y_pred_low, signed * quantile_low ) From 596c123867bbd42d3081509f3147596bd0b4725f Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 11:28:23 +0200 Subject: [PATCH 05/30] UPD: ordering imports --- mapie/conformity_scores/conformity_scores.py | 2 +- mapie/conformity_scores/residual_conformity_scores.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index ffe29341..57fdf117 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -2,10 +2,10 @@ import numpy as np from typing import Tuple +from sklearn.base import RegressorMixin from mapie._compatibility import np_nanquantile from mapie._typing import ArrayLike, NDArray -from sklearn.base import RegressorMixin class ConformityScore(metaclass=ABCMeta): diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/residual_conformity_scores.py index 6d317d29..36ed74a0 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/residual_conformity_scores.py @@ -2,6 +2,7 @@ from mapie._machine_precision import EPSILON from mapie._typing import ArrayLike, NDArray + from mapie.conformity_scores import ConformityScore From 4b403408fb43c020799d58c2f6a176ba5f705106 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 11:36:14 +0200 Subject: [PATCH 06/30] UPD: improve code readability --- mapie/conformity_scores/conformity_scores.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 57fdf117..1f5a0907 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -298,12 +298,18 @@ def get_bounds( alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 if method == "plus": - bound_low = self.get_quantile(self.get_estimation_distribution( + values_low = self.get_estimation_distribution( X, y_pred_low, signed * conformity_scores - ), alpha_low, axis=1, method="lower") - bound_up = self.get_quantile(self.get_estimation_distribution( + ) + values_up = self.get_estimation_distribution( X, y_pred_up, conformity_scores - ), alpha_up, axis=1, method="higher") + ) + bound_low = self.get_quantile( + values_low, alpha_low, axis=1, method="lower" + ) + bound_up = self.get_quantile( + values_up, alpha_up, axis=1, method="higher" + ) else: quantile_search = "higher" if self.sym else "lower" alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 From 8a3b3e7c7db908bac80bf44e11fb3b91062b3ef2 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 12:05:51 +0200 Subject: [PATCH 07/30] UPD: improve code readability --- mapie/conformity_scores/conformity_scores.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 1f5a0907..275f3372 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -294,10 +294,11 @@ def get_bounds( y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) signed = -1 if self.sym else 1 - alpha_low = alpha_np if self.sym else alpha_np / 2 - alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 if method == "plus": + alpha_low = alpha_np if self.sym else alpha_np / 2 + alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 + values_low = self.get_estimation_distribution( X, y_pred_low, signed * conformity_scores ) @@ -313,6 +314,7 @@ def get_bounds( else: quantile_search = "higher" if self.sym else "lower" alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 + alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 quantile_low = self.get_quantile(conformity_scores, alpha_low, axis=0, method=quantile_search) From 73d75aae77687ef3eee9438c28303a38ee126c55 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 12:29:14 +0200 Subject: [PATCH 08/30] UPD: beautify --- mapie/conformity_scores/conformity_scores.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 275f3372..7a033906 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -292,7 +292,6 @@ def get_bounds( - The upper bounds of the prediction intervals. """ y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) - signed = -1 if self.sym else 1 if method == "plus": @@ -316,10 +315,12 @@ def get_bounds( alpha_low = 1 - alpha_np if self.sym else alpha_np / 2 alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 - quantile_low = self.get_quantile(conformity_scores, alpha_low, - axis=0, method=quantile_search) - quantile_up = self.get_quantile(conformity_scores, alpha_up, - axis=0, method="higher") + quantile_low = self.get_quantile( + conformity_scores, alpha_low, axis=0, method=quantile_search + ) + quantile_up = self.get_quantile( + conformity_scores, alpha_up, axis=0, method="higher" + ) bound_low = self.get_estimation_distribution( X, y_pred_low, signed * quantile_low ) From 24c5f354cdcd25824a9ce694ce6bbf147846ed8b Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 12:43:02 +0200 Subject: [PATCH 09/30] UPD: fix accessibility --- mapie/regression/__init__.py | 4 +--- mapie/tests/test_conformity_scores.py | 2 +- mapie/tests/test_regression.py | 3 ++- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/mapie/regression/__init__.py b/mapie/regression/__init__.py index 7360f618..16243ace 100644 --- a/mapie/regression/__init__.py +++ b/mapie/regression/__init__.py @@ -1,11 +1,9 @@ from .regression import MapieRegressor from .quantile_regression import MapieQuantileRegressor from .time_series_regression import MapieTimeSeriesRegressor -from .estimator import EnsembleRegressor __all__ = [ "MapieRegressor", "MapieQuantileRegressor", - "MapieTimeSeriesRegressor", - "EnsembleRegressor" + "MapieTimeSeriesRegressor" ] diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index ce27e227..e9c8c58b 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -4,7 +4,7 @@ from mapie._typing import NDArray, ArrayLike from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) -from mapie.regression import EnsembleRegressor +from mapie.regression.estimator import EnsembleRegressor from sklearn.linear_model import LinearRegression from sklearn.model_selection import KFold diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 810d11b2..2fcddfdb 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -23,7 +23,8 @@ from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) from mapie.metrics import regression_coverage_score -from mapie.regression import MapieRegressor, EnsembleRegressor +from mapie.regression import MapieRegressor +from mapie.regression.estimator import EnsembleRegressor from mapie.subsample import Subsample X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) From 9058824b1ffd4423d4f26a69be8194293d60ee2d Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 14:53:24 +0200 Subject: [PATCH 10/30] UPD: update docstrings --- mapie/conformity_scores/conformity_scores.py | 19 +++++++++---------- .../residual_conformity_scores.py | 12 ++++++------ mapie/regression/estimator.py | 1 - mapie/tests/test_conformity_scores.py | 2 +- 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 7a033906..07a38d50 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -96,8 +96,8 @@ def get_estimation_distribution( Subclasses should implement this method! Compute samples of the estimation distribution from the predicted - targets and ``values``that can be either the conformity scores or - the conformity scores aggregated with the predictions. + targets and ``values`` that can be either the conformity scores or + the quantile of the conformity scores. Parameters ---------- @@ -109,8 +109,8 @@ def get_estimation_distribution( The last dimension is the reference of the prediction. values: ArrayLike - Either the conformity scores or the conformity scores aggregated - with the predictions according to the subclass formula. + Either the conformity scores or the quantile of the conformity + scores aggregated. Returns ------- @@ -213,8 +213,8 @@ def get_quantile( method: str ) -> NDArray: """ - Compute the alpha quantile of the conformity scores considering - the symmetrical property if so. + Compute the alpha quantile of the conformity scores or the conformity + scores aggregated with the predictions. Parameters ---------- @@ -236,8 +236,6 @@ def get_quantile( Returns ------- NDArray - Lower and upper quantile of the prediction intervals. - These quantiles are identical if the score is not symmetrical. """ quantile = np.column_stack([ np_nanquantile( @@ -261,7 +259,7 @@ def get_bounds( ) -> Tuple[NDArray, NDArray, NDArray]: """ Compute bounds of the prediction intervals from the observed values, - the estimator of MapieRegressor and the conformity scores. + the estimator of ``EnsembleRegressor`` and the conformity scores. Parameters ---------- @@ -282,7 +280,8 @@ def get_bounds( Boolean determining whether the predictions are ensembled or not. method: str - The method parameter of MapieRegressor. + The method parameter of MapieRegressor : ``"base"``, ``"minmax"`` + or ``"plus"``. Returns ------- diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/residual_conformity_scores.py index 36ed74a0..59a872ee 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/residual_conformity_scores.py @@ -30,7 +30,7 @@ def get_signed_conformity_scores( ) -> NDArray: """ Compute the signed conformity scores from the observed values - and the estimator, from the following formula: + and the predicted ones, from the following formula: signed conformity score = y - y_pred """ return np.subtract(y, y_pred) @@ -47,8 +47,8 @@ def get_estimation_distribution( signed conformity score = y - y_pred <=> y = y_pred + signed conformity score - ``values`` can be either the conformity scores or - the conformity scores aggregated with the predictions. + ``values`` can be either the conformity scores or the quantile of + the conformity scores. """ return np.add(y_pred, values) @@ -108,7 +108,7 @@ def get_signed_conformity_scores( ) -> NDArray: """ Compute the signed conformity scores from the observed values - and the estimator, from the following formula: + and the predicted ones, from the following formula: signed conformity score = (y - y_pred) / y_pred """ self._check_observed_data(y) @@ -127,8 +127,8 @@ def get_estimation_distribution( signed conformity score = (y - y_pred) / y_pred <=> y = y_pred * (1 + signed conformity score) - ``values`` can be either the conformity scores or - the conformity scores aggregated with the predictions. + ``values`` can be either the conformity scores or the quantile of + the conformity scores. """ self._check_predicted_data(y_pred) return np.multiply(y_pred, np.add(1, values)) diff --git a/mapie/regression/estimator.py b/mapie/regression/estimator.py index 7db6aba3..d4a051e9 100644 --- a/mapie/regression/estimator.py +++ b/mapie/regression/estimator.py @@ -252,7 +252,6 @@ def fit( ------- EnsembleRegressor The estimator fitted. - """ # Initialization single_estimator_: RegressorMixin diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index e9c8c58b..11e25bd5 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -215,7 +215,7 @@ def test_gamma_conformity_score_check_predicted_value( def test_check_consistency() -> None: """ Test that a dummy ConformityScore class that gives inconsistent - conformityscores and distributions raises an error. + conformity scores and distributions raises an error. """ dummy_conf_score = DummyConformityScore() conformity_scores = dummy_conf_score.get_signed_conformity_scores( From 4540f2e12358bdf76ea4c4a4a91bfd3cc721ada2 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 15:09:00 +0200 Subject: [PATCH 11/30] UPD: update parameter name --- mapie/conformity_scores/conformity_scores.py | 13 ++++---- .../residual_conformity_scores.py | 22 ++++++------- mapie/tests/test_conformity_scores.py | 32 ++++++++----------- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 07a38d50..440a3e04 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -89,15 +89,15 @@ def get_estimation_distribution( self, X: ArrayLike, y_pred: ArrayLike, - values: ArrayLike + conformity_scores: ArrayLike ) -> NDArray: """ Placeholder for ``get_estimation_distribution``. Subclasses should implement this method! Compute samples of the estimation distribution from the predicted - targets and ``values`` that can be either the conformity scores or - the quantile of the conformity scores. + targets and ``conformity_scores`` that can be either the conformity + scores or the quantile of the conformity scores. Parameters ---------- @@ -108,7 +108,7 @@ def get_estimation_distribution( Predicted reference values of shape (n_samples, ...). The last dimension is the reference of the prediction. - values: ArrayLike + conformity_scores: ArrayLike Either the conformity scores or the quantile of the conformity scores aggregated. @@ -155,9 +155,8 @@ def check_consistency( ValueError If the two methods are not consistent. """ - score_distribution = self.get_estimation_distribution( - X, y_pred, conformity_scores - ) + score_distribution = self.get_estimation_distribution(X, y_pred, + conformity_scores) abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) max_conf_score = np.max(abs_conformity_scores) if max_conf_score > self.eps: diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/residual_conformity_scores.py index 59a872ee..92772f32 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/residual_conformity_scores.py @@ -39,18 +39,18 @@ def get_estimation_distribution( self, X: ArrayLike, y_pred: ArrayLike, - values: ArrayLike - ): + conformity_scores: ArrayLike + ) -> NDArray: """ Compute samples of the estimation distribution from the predicted - targets and ``values``, from the following formula: + targets and ``conformity_scores``, from the following formula: signed conformity score = y - y_pred <=> y = y_pred + signed conformity score - ``values`` can be either the conformity scores or the quantile of - the conformity scores. + ``conformity_scores`` can be either the conformity scores or + the quantile of the conformity scores. """ - return np.add(y_pred, values) + return np.add(y_pred, conformity_scores) class GammaConformityScore(ConformityScore): @@ -119,16 +119,16 @@ def get_estimation_distribution( self, X: ArrayLike, y_pred: ArrayLike, - values: ArrayLike, + conformity_scores: ArrayLike ) -> NDArray: """ Compute samples of the estimation distribution from the predicted - targets and ``values``, from the following formula: + targets and ``conformity_scores``, from the following formula: signed conformity score = (y - y_pred) / y_pred <=> y = y_pred * (1 + signed conformity score) - ``values`` can be either the conformity scores or the quantile of - the conformity scores. + ``conformity_scores`` can be either the conformity scores or + the quantile of the conformity scores. """ self._check_predicted_data(y_pred) - return np.multiply(y_pred, np.add(1, values)) + return np.multiply(y_pred, np.add(1, conformity_scores)) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 11e25bd5..3d8827f5 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -28,15 +28,14 @@ def get_signed_conformity_scores( ) -> NDArray: return np.subtract(y, y_pred) - def get_estimation_distribution( - self, X: ArrayLike, y_pred: ArrayLike, values: ArrayLike - ) -> NDArray: + def get_estimation_distribution(self, X: ArrayLike, y_pred: ArrayLike, + conformity_scores: ArrayLike) -> NDArray: """ A positive constant is added to the sum between predictions and conformity scores to make the estimated distribution inconsistent with the conformity score. """ - return np.add(y_pred, values) + 1 + return np.add(y_pred, conformity_scores) + 1 estimator_toy = EnsembleRegressor( @@ -85,9 +84,8 @@ def test_absolute_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for AbsoluteConformityScore.""" # noqa: E501 abs_conf_score = AbsoluteConformityScore() - y_obs = abs_conf_score.get_estimation_distribution( - X_toy_test, y_pred, conf_scores - ) + y_obs = abs_conf_score.get_estimation_distribution(X_toy_test, y_pred, + conf_scores) np.testing.assert_allclose(y_obs, y_toy_test) @@ -98,9 +96,8 @@ def test_absolute_conformity_score_consistency(y_pred: NDArray) -> None: signed_conf_scores = abs_conf_score.get_signed_conformity_scores( X_toy_test, y_toy_test, y_pred ) - y_obs = abs_conf_score.get_estimation_distribution( - X_toy_test, y_pred, signed_conf_scores - ) + y_obs = abs_conf_score.get_estimation_distribution(X_toy_test, y_pred, + signed_conf_scores) np.testing.assert_allclose(y_obs, y_toy_test) @@ -129,9 +126,8 @@ def test_gamma_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for GammaConformityScore.""" # noqa: E501 gamma_conf_score = GammaConformityScore() - y_obs = gamma_conf_score.get_estimation_distribution( - X_toy_test, y_pred, conf_scores - ) + y_obs = gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + conf_scores) np.testing.assert_allclose(y_obs, y_toy_test) @@ -142,9 +138,8 @@ def test_gamma_conformity_score_consistency(y_pred: NDArray) -> None: signed_conf_scores = gamma_conf_score.get_signed_conformity_scores( X_toy_test, y_toy_test, y_pred ) - y_obs = gamma_conf_score.get_estimation_distribution( - X_toy_test, y_pred, signed_conf_scores - ) + y_obs = gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + signed_conf_scores) np.testing.assert_allclose(y_obs, y_toy_test) @@ -207,9 +202,8 @@ def test_gamma_conformity_score_check_predicted_value( ValueError, match=r".*At least one of the predicted target is negative.*" ): - gamma_conf_score.get_estimation_distribution( - X_toy_test, y_pred, conf_scores - ) + gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + conf_scores) def test_check_consistency() -> None: From bf84501876efee04092c6d63cf2307171ace32d6 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 15:33:21 +0200 Subject: [PATCH 12/30] UPD: docstrings --- mapie/conformity_scores/conformity_scores.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index 440a3e04..de86fe92 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -235,6 +235,7 @@ def get_quantile( Returns ------- NDArray + The quantile of the conformity scores. """ quantile = np.column_stack([ np_nanquantile( @@ -258,7 +259,7 @@ def get_bounds( ) -> Tuple[NDArray, NDArray, NDArray]: """ Compute bounds of the prediction intervals from the observed values, - the estimator of ``EnsembleRegressor`` and the conformity scores. + the estimator of type ``EnsembleRegressor`` and the conformity scores. Parameters ---------- @@ -279,8 +280,11 @@ def get_bounds( Boolean determining whether the predictions are ensembled or not. method: str - The method parameter of MapieRegressor : ``"base"``, ``"minmax"`` - or ``"plus"``. + Method to choose for prediction interval estimates. + The ``"plus"`` method implies that the quantile is calculated + after estimating the bounds, whereas the other methods + (among the ``"naive"``, ``"base"`` or ``"minmax"`` methods, + for example) do the opposite. Returns ------- From 7d83f8e9d1dc5a492697f3bf0b59d02c317ce8a7 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 15:40:24 +0200 Subject: [PATCH 13/30] UPD: improve code readability --- mapie/regression/estimator.py | 178 +++++++++++++------------- mapie/regression/regression.py | 4 +- mapie/tests/test_conformity_scores.py | 2 +- mapie/tests/test_regression.py | 16 ++- 4 files changed, 102 insertions(+), 98 deletions(-) diff --git a/mapie/regression/estimator.py b/mapie/regression/estimator.py index d4a051e9..cc14d11f 100644 --- a/mapie/regression/estimator.py +++ b/mapie/regression/estimator.py @@ -223,76 +223,37 @@ def _fit_oof_estimator( ) return estimator - def fit( - self, + @staticmethod + def _predict_oof_estimator( + estimator: RegressorMixin, X: ArrayLike, - y: ArrayLike, - sample_weight: Optional[ArrayLike] = None, - ) -> EnsembleRegressor: + val_index: ArrayLike, + ): """ - Fit the base estimator under the ``single_estimator_`` attribute. - Fit all cross-validated estimator clones - and rearrange them into a list, the ``estimators_`` attribute. - Out-of-fold conformity scores are stored under - the ``conformity_scores_`` attribute. + Perform predictions on a single out-of-fold model on a validation set. Parameters ---------- + estimator: RegressorMixin + Estimator to train. + X: ArrayLike of shape (n_samples, n_features) Input data. - y: ArrayLike of shape (n_samples,) - Input labels. - - sample_weight: Optional[ArrayLike] of shape (n_samples,) - Sample weights. If None, then samples are equally weighted. - By default ``None``. + val_index: ArrayLike of shape (n_samples_val) + Validation data indices. Returns ------- - EnsembleRegressor - The estimator fitted. + Tuple[NDArray, ArrayLike] + Predictions of estimator from val_index of X. """ - # Initialization - single_estimator_: RegressorMixin - estimators_: List[RegressorMixin] = [] - full_indexes = np.arange(_num_samples(X)) - cv = self.cv - estimator = self.estimator - n_samples = _num_samples(y) - - # Computation - if cv == "prefit": - single_estimator_ = estimator - self.k_ = np.full( - shape=(n_samples, 1), fill_value=np.nan, dtype=float - ) + X_val = _safe_indexing(X, val_index) + if _num_samples(X_val) > 0: + y_pred = estimator.predict(X_val) else: - single_estimator_ = self._fit_oof_estimator( - clone(estimator), X, y, full_indexes, sample_weight - ) - cv = cast(BaseCrossValidator, cv) - self.k_ = np.full( - shape=(n_samples, cv.get_n_splits(X, y)), - fill_value=np.nan, - dtype=float, - ) - if self.method == "naive": - estimators_ = [single_estimator_] - else: - estimators_ = Parallel(self.n_jobs, verbose=self.verbose)( - delayed(self._fit_oof_estimator)( - clone(estimator), X, y, train_index, sample_weight - ) - for train_index, _ in cv.split(X) - ) - if isinstance(cv, ShuffleSplit): - single_estimator_ = estimators_[0] - - self.single_estimator_ = single_estimator_ - self.estimators_ = estimators_ - - return self + y_pred = np.array([]) + return y_pred, val_index def _aggregate_with_mask( self, @@ -340,38 +301,6 @@ def _aggregate_with_mask( else: raise ValueError("The value of self.agg_function is not correct") - @staticmethod - def _predict_oof_estimator( - estimator: RegressorMixin, - X: ArrayLike, - val_index: ArrayLike, - ): - """ - Perform predictions on a single out-of-fold model on a validation set. - - Parameters - ---------- - estimator: RegressorMixin - Estimator to train. - - X: ArrayLike of shape (n_samples, n_features) - Input data. - - val_index: ArrayLike of shape (n_samples_val) - Validation data indices. - - Returns - ------- - Tuple[NDArray, ArrayLike] - Predictions of estimator from val_index of X. - """ - X_val = _safe_indexing(X, val_index) - if _num_samples(X_val) > 0: - y_pred = estimator.predict(X_val) - else: - y_pred = np.array([]) - return y_pred, val_index - def _pred_multi(self, X: ArrayLike) -> NDArray: """ Return a prediction per train sample for each test sample, by @@ -447,6 +376,77 @@ def predict_calib(self, X: ArrayLike) -> NDArray: return y_pred + def fit( + self, + X: ArrayLike, + y: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + ) -> EnsembleRegressor: + """ + Fit the base estimator under the ``single_estimator_`` attribute. + Fit all cross-validated estimator clones + and rearrange them into a list, the ``estimators_`` attribute. + Out-of-fold conformity scores are stored under + the ``conformity_scores_`` attribute. + + Parameters + ---------- + X: ArrayLike of shape (n_samples, n_features) + Input data. + + y: ArrayLike of shape (n_samples,) + Input labels. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Sample weights. If None, then samples are equally weighted. + By default ``None``. + + Returns + ------- + EnsembleRegressor + The estimator fitted. + """ + # Initialization + single_estimator_: RegressorMixin + estimators_: List[RegressorMixin] = [] + full_indexes = np.arange(_num_samples(X)) + cv = self.cv + estimator = self.estimator + n_samples = _num_samples(y) + + # Computation + if cv == "prefit": + single_estimator_ = estimator + self.k_ = np.full( + shape=(n_samples, 1), fill_value=np.nan, dtype=float + ) + else: + single_estimator_ = self._fit_oof_estimator( + clone(estimator), X, y, full_indexes, sample_weight + ) + cv = cast(BaseCrossValidator, cv) + self.k_ = np.full( + shape=(n_samples, cv.get_n_splits(X, y)), + fill_value=np.nan, + dtype=float, + ) + if self.method == "naive": + estimators_ = [single_estimator_] + else: + estimators_ = Parallel(self.n_jobs, verbose=self.verbose)( + delayed(self._fit_oof_estimator)( + clone(estimator), X, y, train_index, sample_weight + ) + for train_index, _ in cv.split(X) + ) + if isinstance(cv, ShuffleSplit): + single_estimator_ = estimators_[0] + + self.single_estimator_ = single_estimator_ + self.estimators_ = estimators_ + + return self + def predict( self, X: ArrayLike, diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 5e007255..087a1fa2 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -545,7 +545,7 @@ def predict( alpha_np = cast(NDArray, alpha) check_alpha_and_n_samples(alpha_np, n) - y_pred, bound_low, bound_up = \ + y_pred, y_pred_low, y_pred_up = \ self.conformity_score_function_.get_bounds( X, self.estimator_, @@ -554,7 +554,7 @@ def predict( ensemble, self.method ) - return y_pred, np.stack([bound_low, bound_up], axis=1) + return y_pred, np.stack([y_pred_low, y_pred_up], axis=1) else: y_pred, _, _ = self.estimator_.predict(X, ensemble) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 3d8827f5..c78b0d57 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from mapie._typing import NDArray, ArrayLike +from mapie._typing import ArrayLike, NDArray from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) from mapie.regression.estimator import EnsembleRegressor diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 2fcddfdb..2d62d94f 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -514,12 +514,16 @@ def test_aggregate_with_mask_with_prefit() -> None: def test_aggregate_with_mask_with_invalid_agg_function() -> None: """Test ``_aggregate_with_mask`` in case ``agg_function`` is invalid.""" - ens_reg = EnsembleRegressor(LinearRegression(), "plus", - KFold( - n_splits=5, random_state=None, shuffle=True - ), - "nonsense", None, random_state, 0.20, False - ) + ens_reg = EnsembleRegressor( + LinearRegression(), + "plus", + KFold(n_splits=5, random_state=None, shuffle=True), + "nonsense", + None, + random_state, + 0.20, + False + ) with pytest.raises( ValueError, match=r".*The value of self.agg_function is not correct*", From efc597a0e74e7c87270694c959d37a72f03a5efa Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 15:59:31 +0200 Subject: [PATCH 14/30] UPD: delete changes of toyset in unit tests --- mapie/tests/test_conformity_scores.py | 48 ++++++++++++--------------- 1 file changed, 22 insertions(+), 26 deletions(-) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index c78b0d57..66dc896f 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -8,14 +8,11 @@ from sklearn.linear_model import LinearRegression from sklearn.model_selection import KFold -X_toy_train = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) -y_toy_train = np.array([5, 7, 9, 11, 13, 15]) -X_toy_test = np.array([6, 9, 10, 2, 4, 5]).reshape(-1, 1) -y_toy_test = np.array([15, 4, 90, 1, 15, 1]) -y_pred_list = [17., 23., 25., 9., 13., 15.] -conf_scores_list = [-2., -19., 65., -8., 2., -14.] -conf_scores_gamma_list = [-0.11764706, -0.82608696, 2.6, - -0.88888889, 0.15384615, -0.93333333] +X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) +y_toy = np.array([5, 7, 9, 11, 13, 15]) +y_pred_list = [4, 7, 10, 12, 13, 12] +conf_scores_list = [1, 0, -1, -1, 0, 3] +conf_scores_gamma_list = [1 / 4, 0, -1 / 10, -1 / 12, 0, 3 / 12] random_state = 42 @@ -48,7 +45,6 @@ def get_estimation_distribution(self, X: ArrayLike, y_pred: ArrayLike, 0.20, False ) -estimator_toy_fitted = estimator_toy.fit(X_toy_train, y_toy_train) @pytest.mark.parametrize("sym", [False, True]) @@ -64,10 +60,10 @@ def test_absolute_conformity_score_get_conformity_scores( """Test conformity score computation for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - X_toy_test, y_toy_test, y_pred + X_toy, y_toy, y_pred ) conf_scores = abs_conf_score.get_conformity_scores( - X_toy_test, y_toy_test, y_pred + X_toy, y_toy, y_pred ) expected_signed_conf_scores = np.array(conf_scores_list) expected_conf_scores = np.abs(expected_signed_conf_scores) @@ -84,9 +80,9 @@ def test_absolute_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for AbsoluteConformityScore.""" # noqa: E501 abs_conf_score = AbsoluteConformityScore() - y_obs = abs_conf_score.get_estimation_distribution(X_toy_test, y_pred, + y_obs = abs_conf_score.get_estimation_distribution(X_toy, y_pred, conf_scores) - np.testing.assert_allclose(y_obs, y_toy_test) + np.testing.assert_allclose(y_obs, y_toy) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -94,11 +90,11 @@ def test_absolute_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for AbsoluteConformityScore.""" abs_conf_score = AbsoluteConformityScore() signed_conf_scores = abs_conf_score.get_signed_conformity_scores( - X_toy_test, y_toy_test, y_pred + X_toy, y_toy, y_pred ) - y_obs = abs_conf_score.get_estimation_distribution(X_toy_test, y_pred, + y_obs = abs_conf_score.get_estimation_distribution(X_toy, y_pred, signed_conf_scores) - np.testing.assert_allclose(y_obs, y_toy_test) + np.testing.assert_allclose(y_obs, y_toy) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -108,7 +104,7 @@ def test_gamma_conformity_score_get_conformity_scores( """Test conformity score computation for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() conf_scores = gamma_conf_score.get_conformity_scores( - X_toy_test, y_toy_test, y_pred) + X_toy, y_toy, y_pred) expected_signed_conf_scores = np.array(conf_scores_gamma_list) np.testing.assert_allclose(conf_scores, expected_signed_conf_scores) @@ -126,9 +122,9 @@ def test_gamma_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for GammaConformityScore.""" # noqa: E501 gamma_conf_score = GammaConformityScore() - y_obs = gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + y_obs = gamma_conf_score.get_estimation_distribution(X_toy, y_pred, conf_scores) - np.testing.assert_allclose(y_obs, y_toy_test) + np.testing.assert_allclose(y_obs, y_toy) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -136,11 +132,11 @@ def test_gamma_conformity_score_consistency(y_pred: NDArray) -> None: """Test methods consistency for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() signed_conf_scores = gamma_conf_score.get_signed_conformity_scores( - X_toy_test, y_toy_test, y_pred + X_toy, y_toy, y_pred ) - y_obs = gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + y_obs = gamma_conf_score.get_estimation_distribution(X_toy, y_pred, signed_conf_scores) - np.testing.assert_allclose(y_obs, y_toy_test) + np.testing.assert_allclose(y_obs, y_toy) @pytest.mark.parametrize("y_pred", [np.array(y_pred_list), y_pred_list]) @@ -196,13 +192,13 @@ def test_gamma_conformity_score_check_predicted_value( match=r".*At least one of the predicted target is negative.*" ): gamma_conf_score.get_signed_conformity_scores( - X_toy, y_toy_test, y_pred + X_toy, y_toy, y_pred ) with pytest.raises( ValueError, match=r".*At least one of the predicted target is negative.*" ): - gamma_conf_score.get_estimation_distribution(X_toy_test, y_pred, + gamma_conf_score.get_estimation_distribution(X_toy, y_pred, conf_scores) @@ -213,12 +209,12 @@ def test_check_consistency() -> None: """ dummy_conf_score = DummyConformityScore() conformity_scores = dummy_conf_score.get_signed_conformity_scores( - X_toy_test, y_toy_test, y_pred_list + X_toy, y_toy, y_pred_list ) with pytest.raises( ValueError, match=r".*The two functions get_conformity_scores.*" ): dummy_conf_score.check_consistency( - X_toy_test, y_toy_test, y_pred_list, conformity_scores + X_toy, y_toy, y_pred_list, conformity_scores ) From 330bbfbee7b1923c1ea5a917a39fb33b82e7a912 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 16:24:20 +0200 Subject: [PATCH 15/30] UPD: take TCO comments into account --- mapie/regression/regression.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 087a1fa2..5c4c9459 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -540,7 +540,11 @@ def predict( self._check_ensemble(ensemble) alpha = cast(Optional[NDArray], check_alpha(alpha)) - if not (alpha is None): + if alpha is None: + y_pred, _, _ = self.estimator_.predict(X, ensemble) + return y_pred + + else: n = len(self.conformity_scores_) alpha_np = cast(NDArray, alpha) check_alpha_and_n_samples(alpha_np, n) @@ -555,7 +559,3 @@ def predict( self.method ) return y_pred, np.stack([y_pred_low, y_pred_up], axis=1) - - else: - y_pred, _, _ = self.estimator_.predict(X, ensemble) - return y_pred From 30b03df3f02b78bf7de1aad81c8bd085edd29204 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 16:25:35 +0200 Subject: [PATCH 16/30] UPD: delete useless instance in test_conformity_scores.py --- mapie/tests/test_conformity_scores.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 66dc896f..be4b942a 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -35,18 +35,6 @@ def get_estimation_distribution(self, X: ArrayLike, y_pred: ArrayLike, return np.add(y_pred, conformity_scores) + 1 -estimator_toy = EnsembleRegressor( - LinearRegression(), - "plus", - KFold(n_splits=5, random_state=None, shuffle=True), - "mean", - None, - random_state, - 0.20, - False -) - - @pytest.mark.parametrize("sym", [False, True]) def test_error_mother_class_initialization(sym: bool) -> None: with pytest.raises(TypeError): From f7a4ca87718217f81962a2aeacaf9bcc1f85e5e4 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 18:32:00 +0200 Subject: [PATCH 17/30] FIX: delete unused imports --- mapie/tests/test_conformity_scores.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index be4b942a..4fb7f635 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -4,9 +4,6 @@ from mapie._typing import ArrayLike, NDArray from mapie.conformity_scores import (AbsoluteConformityScore, ConformityScore, GammaConformityScore) -from mapie.regression.estimator import EnsembleRegressor -from sklearn.linear_model import LinearRegression -from sklearn.model_selection import KFold X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) y_toy = np.array([5, 7, 9, 11, 13, 15]) From 14175bbf33e205e2dd29372a9b95dd37eec0abb4 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Thu, 6 Jul 2023 18:32:23 +0200 Subject: [PATCH 18/30] UPD: add shapes to docstrings --- mapie/conformity_scores/conformity_scores.py | 66 +++++++++++--------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index de86fe92..b8073bb4 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -69,18 +69,18 @@ def get_signed_conformity_scores( Parameters ---------- - X: ArrayLike + X: ArrayLike of shape (n_samples_calib, n_features) Observed feature values. - y: ArrayLike + y: ArrayLike of shape (n_samples_calib,) Observed target values. - y_pred: ArrayLike + y_pred: ArrayLike of shape (n_samples_calib,) Predicted target values. Returns ------- - NDArray + NDArray of shape (n_samples_calib,) Signed conformity scores. """ @@ -101,20 +101,24 @@ def get_estimation_distribution( Parameters ---------- - X: ArrayLike + X: ArrayLike of shape (n_samples_calib, n_features) Observed feature values. y_pred: ArrayLike - Predicted reference values of shape (n_samples, ...). - The last dimension is the reference of the prediction. + The shape is either (n_samples_calib, n_samples_train): when the + method is called in ``get_bounds`` it needs a prediction per train + sample for each test sample to compute the bounds. + Or (n_samples_calib, 1): when it is called in ``check_consistency`` conformity_scores: ArrayLike - Either the conformity scores or the quantile of the conformity - scores aggregated. + The shape is either (n_samples_calib, n_alpha) when it is the + conformity scores themselves or (n_alpha, 1) when it is only the + quantile of the conformity scores. Returns ------- - NDArray + NDArray of shape (n_samples_calib, n_alpha) or + (n_samples_calib, n_samples_train) according to the shape of ``y_pred`` Observed values. """ @@ -138,16 +142,16 @@ def check_consistency( Parameters ---------- - X: ArrayLike + X: ArrayLike of shape (n_samples_calib, n_features) Observed feature values. - y: ArrayLike + y: ArrayLike of shape (n_samples_calib,) Observed target values. - y_pred: ArrayLike + y_pred: ArrayLike of shape (n_samples_calib,) Predicted target values. - conformity_scores: ArrayLike + conformity_scores: ArrayLike of shape (n_samples_calib,) Conformity scores. Raises @@ -155,8 +159,9 @@ def check_consistency( ValueError If the two methods are not consistent. """ - score_distribution = self.get_estimation_distribution(X, y_pred, - conformity_scores) + score_distribution = self.get_estimation_distribution( + X, y_pred, conformity_scores + ) abs_conformity_scores = np.abs(np.subtract(score_distribution, y)) max_conf_score = np.max(abs_conformity_scores) if max_conf_score > self.eps: @@ -183,18 +188,18 @@ def get_conformity_scores( Parameters ---------- - X: NDArray + X: NDArray of shape (n_samples_calib, n_features) Observed feature values. - y: NDArray + y: NDArray of shape (n_samples_calib,) Observed target values. - y_pred: NDArray + y_pred: NDArray of shape (n_samples_calib,) Predicted target values. Returns ------- - NDArray + NDArray of shape (n_samples_calib, 1) Conformity scores. """ conformity_scores = self.get_signed_conformity_scores(X, y, y_pred) @@ -217,12 +222,13 @@ def get_quantile( Parameters ---------- - values: NDArray + values: NDArray of shape (n_samples_calib, n_alpha) or + (n_samples_calib, n_samples_train) Values from which the quantile is computed, it can be the conformity scores or the conformity scores aggregated with the predictions. - alpha_np: NDArray + alpha_np: NDArray of shape (n_alpha,) NDArray of floats between ``0`` and ``1``, represents the uncertainty of the confidence interval. @@ -234,7 +240,7 @@ def get_quantile( Returns ------- - NDArray + NDArray of shape (n_alpha,) The quantile of the conformity scores. """ quantile = np.column_stack([ @@ -263,16 +269,16 @@ def get_bounds( Parameters ---------- - X: ArrayLike + X: ArrayLike of shape (n_samples_test, n_features) Observed feature values. estimator: RegressorMixin Estimator that is fitted to predict y from X. - conformity_scores: ArrayLike + conformity_scores: ArrayLike of shape (n_samples_calib,) Conformity scores. - alpha_np: NDArray + alpha_np: NDArray of shape (n_alpha,) NDArray of floats between ``0`` and ``1``, represents the uncertainty of the confidence interval. @@ -289,9 +295,11 @@ def get_bounds( Returns ------- Tuple[NDArray, NDArray, NDArray] - - The predictions itself. (y_pred) - - The lower bounds of the prediction intervals. - - The upper bounds of the prediction intervals. + - The predictions itself. (y_pred) of shape (n_samples_test,). + - The lower bounds of the prediction intervals of shape + (n_samples_test,). + - The upper bounds of the prediction intervals of shape + (n_samples_test,). """ y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) signed = -1 if self.sym else 1 From 437f4a703718ef687b07fed6a6f89d60858c47df Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Fri, 7 Jul 2023 13:55:04 +0200 Subject: [PATCH 19/30] UPD: two types of return possible in predict --- mapie/regression/estimator.py | 21 ++++++++++++++------- mapie/regression/regression.py | 8 +++++--- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/mapie/regression/estimator.py b/mapie/regression/estimator.py index cc14d11f..c65fb26d 100644 --- a/mapie/regression/estimator.py +++ b/mapie/regression/estimator.py @@ -266,8 +266,8 @@ def _aggregate_with_mask( sample which one to integrate, and aggregate to produce phi-{t}(x_t) for each training sample x_t. - Parameters: - ----------- + Parameters + ---------- x: ArrayLike of shape (n_samples_test, n_estimators) Array of predictions, made by the refitted estimators, for each sample of the testing set. @@ -277,8 +277,8 @@ def _aggregate_with_mask( of a given estimator into the aggregation, for each training sample. - Returns: - -------- + Returns + ------- ArrayLike of shape (n_samples_test,) Array of aggregated predictions for each testing sample. """ @@ -450,7 +450,8 @@ def fit( def predict( self, X: ArrayLike, - ensemble: bool = False + ensemble: bool = False, + return_multi_pred: bool = True ) -> Union[NDArray, Tuple[NDArray, NDArray, NDArray]]: """ Predict target from X. It also computes the prediction per train sample @@ -473,6 +474,9 @@ def predict( By default ``False``. + return_multi_pred: bool + + Returns ------- Tuple[NDArray, NDArray, NDArray] @@ -480,7 +484,6 @@ def predict( - The multiple predictions for the lower bound of the intervals. - The multiple predictions for the upper bound of the intervals. """ - check_is_fitted(self, self.fit_attributes) y_pred = self.single_estimator_.predict(X) @@ -500,4 +503,8 @@ def predict( if ensemble: y_pred = aggregate_all(self.agg_function, y_pred_multi) - return y_pred, y_pred_multi_low, y_pred_multi_up + + if return_multi_pred: + return y_pred, y_pred_multi_low, y_pred_multi_up + else: + return y_pred diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 5c4c9459..78141a75 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -541,8 +541,10 @@ def predict( alpha = cast(Optional[NDArray], check_alpha(alpha)) if alpha is None: - y_pred, _, _ = self.estimator_.predict(X, ensemble) - return y_pred + y_pred = self.estimator_.predict( + X, ensemble, return_multi_pred=False + ) + return np.array(y_pred) else: n = len(self.conformity_scores_) @@ -558,4 +560,4 @@ def predict( ensemble, self.method ) - return y_pred, np.stack([y_pred_low, y_pred_up], axis=1) + return np.array(y_pred), np.stack([y_pred_low, y_pred_up], axis=1) From 7eb73606eb1b4ca7a1430a3f5063ed380009d30c Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Fri, 7 Jul 2023 13:57:11 +0200 Subject: [PATCH 20/30] UPD: error in docstring --- mapie/regression/estimator.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mapie/regression/estimator.py b/mapie/regression/estimator.py index c65fb26d..ba787871 100644 --- a/mapie/regression/estimator.py +++ b/mapie/regression/estimator.py @@ -327,8 +327,7 @@ def _pred_multi(self, X: ArrayLike) -> NDArray: def predict_calib(self, X: ArrayLike) -> NDArray: """ - Perform predictions on X : the calibration set. This method is - called in the ConformityScore class to compute the conformity scores. + Perform predictions on X : the calibration set. Parameters ---------- From 90eb15263d5aff6462c93a57d72de2d38e5c49f3 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Fri, 7 Jul 2023 14:14:17 +0200 Subject: [PATCH 21/30] UPD: take TCO comment into account --- mapie/tests/test_regression.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 2d62d94f..5dd6f93a 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -534,8 +534,8 @@ def test_aggregate_with_mask_with_invalid_agg_function() -> None: def test_pred_loof_isnan() -> None: """Test that if validation set is empty then prediction is empty.""" mapie_reg = MapieRegressor() - mapie_reg = mapie_reg.fit(X, y) y_pred: NDArray + mapie_reg = mapie_reg.fit(X, y) y_pred, _ = mapie_reg.estimator_._predict_oof_estimator( estimator=LinearRegression(), X=X_toy, From 1dbf7c69b37b60a3f904d0d2d3a0b3c99a82305d Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Fri, 7 Jul 2023 15:41:37 +0200 Subject: [PATCH 22/30] ENH: add estimator module --- mapie/conformity_scores/conformity_scores.py | 8 +- mapie/estimator/__init__.py | 0 mapie/{regression => estimator}/estimator.py | 3 +- mapie/estimator/interface.py | 340 +++++++++++++++++++ mapie/regression/regression.py | 2 +- mapie/tests/test_regression.py | 2 +- 6 files changed, 348 insertions(+), 7 deletions(-) create mode 100644 mapie/estimator/__init__.py rename mapie/{regression => estimator}/estimator.py (99%) create mode 100644 mapie/estimator/interface.py diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index b8073bb4..d667d2a0 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -2,10 +2,10 @@ import numpy as np from typing import Tuple -from sklearn.base import RegressorMixin from mapie._compatibility import np_nanquantile from mapie._typing import ArrayLike, NDArray +from mapie.estimator.interface import EnsembleEstimator class ConformityScore(metaclass=ABCMeta): @@ -257,7 +257,7 @@ def get_quantile( def get_bounds( self, X: ArrayLike, - estimator: RegressorMixin, + estimator: EnsembleEstimator, conformity_scores: NDArray, alpha_np: NDArray, ensemble: bool, @@ -265,14 +265,14 @@ def get_bounds( ) -> Tuple[NDArray, NDArray, NDArray]: """ Compute bounds of the prediction intervals from the observed values, - the estimator of type ``EnsembleRegressor`` and the conformity scores. + the estimator of type ``EnsembleEstimator`` and the conformity scores. Parameters ---------- X: ArrayLike of shape (n_samples_test, n_features) Observed feature values. - estimator: RegressorMixin + estimator: EnsembleEstimator Estimator that is fitted to predict y from X. conformity_scores: ArrayLike of shape (n_samples_calib,) diff --git a/mapie/estimator/__init__.py b/mapie/estimator/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mapie/regression/estimator.py b/mapie/estimator/estimator.py similarity index 99% rename from mapie/regression/estimator.py rename to mapie/estimator/estimator.py index ba787871..167bd2a6 100644 --- a/mapie/regression/estimator.py +++ b/mapie/estimator/estimator.py @@ -13,9 +13,10 @@ from mapie.aggregation_functions import aggregate_all, phi2D from mapie.utils import (check_nan_in_aposteriori_prediction, fit_estimator) +from mapie.estimator.interface import EnsembleEstimator -class EnsembleRegressor(RegressorMixin): +class EnsembleRegressor(EnsembleEstimator): """ This class implements methods to handle the training and usage of the estimator. This estimator can be unique or composed by cross validated diff --git a/mapie/estimator/interface.py b/mapie/estimator/interface.py new file mode 100644 index 00000000..25435fe6 --- /dev/null +++ b/mapie/estimator/interface.py @@ -0,0 +1,340 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +from sklearn.base import RegressorMixin + +from mapie._typing import ArrayLike, NDArray + + +class EnsembleEstimator(RegressorMixin): + """ + This class implements methods to handle the training and usage of the + estimator. This estimator can be unique or composed by cross validated + estimators. + + Parameters + ---------- + estimator: Optional[RegressorMixin] + Any regressor with scikit-learn API + (i.e. with ``fit`` and ``predict`` methods). + If ``None``, estimator defaults to a ``LinearRegression`` instance. + + By default ``None``. + + method: str + Method to choose for prediction interval estimates. + Choose among: + + - ``"naive"``, based on training set conformity scores, + - ``"base"``, based on validation sets conformity scores, + - ``"plus"``, based on validation conformity scores and + testing predictions, + - ``"minmax"``, based on validation conformity scores and + testing predictions (min/max among cross-validation clones). + + By default ``"plus"``. + + cv: Optional[Union[int, str, BaseCrossValidator]] + The cross-validation strategy for computing conformity scores. + It directly drives the distinction between jackknife and cv variants. + Choose among: + + - ``None``, to use the default 5-fold cross-validation + - integer, to specify the number of folds. + If equal to ``-1``, equivalent to + ``sklearn.model_selection.LeaveOneOut()``. + - CV splitter: any ``sklearn.model_selection.BaseCrossValidator`` + Main variants are: + - ``sklearn.model_selection.LeaveOneOut`` (jackknife), + - ``sklearn.model_selection.KFold`` (cross-validation), + - ``subsample.Subsample`` object (bootstrap). + - ``"split"``, does not involve cross-validation but a division + of the data into training and calibration subsets. The splitter + used is the following: ``sklearn.model_selection.ShuffleSplit``. + - ``"prefit"``, assumes that ``estimator`` has been fitted already, + and the ``method`` parameter is ignored. + All data provided in the ``fit`` method is then used + for computing conformity scores only. + At prediction time, quantiles of these conformity scores are used + to provide a prediction interval with fixed width. + The user has to take care manually that data for model fitting and + conformity scores estimate are disjoint. + + By default ``None``. + + test_size: Optional[Union[int, float]] + If ``float``, should be between ``0.0`` and ``1.0`` and represent the + proportion of the dataset to include in the test split. If ``int``, + represents the absolute number of test samples. If ``None``, + it will be set to ``0.1``. + + If cv is not ``"split"``, ``test_size`` is ignored. + + By default ``None``. + + n_jobs: Optional[int] + Number of jobs for parallel processing using joblib + via the "locky" backend. + If ``-1`` all CPUs are used. + If ``1`` is given, no parallel computing code is used at all, + which is useful for debugging. + For ``n_jobs`` below ``-1``, ``(n_cpus + 1 - n_jobs)`` are used. + ``None`` is a marker for `unset` that will be interpreted as + ``n_jobs=1`` (sequential execution). + + By default ``None``. + + agg_function: Optional[str] + Determines how to aggregate predictions from perturbed models, both at + training and prediction time. + + If ``None``, it is ignored except if ``cv`` class is ``Subsample``, + in which case an error is raised. + If ``"mean"`` or ``"median"``, returns the mean or median of the + predictions computed from the out-of-folds models. + Note: if you plan to set the ``ensemble`` argument to ``True`` in the + ``predict`` method, you have to specify an aggregation function. + Otherwise an error would be raised. + + The Jackknife+ interval can be interpreted as an interval around the + median prediction, and is guaranteed to lie inside the interval, + unlike the single estimator predictions. + + When the cross-validation strategy is ``Subsample`` (i.e. for the + Jackknife+-after-Bootstrap method), this function is also used to + aggregate the training set in-sample predictions. + + If ``cv`` is ``"prefit"`` or ``"split"``, ``agg_function`` is ignored. + + By default ``"mean"``. + + verbose: int + The verbosity level, used with joblib for multiprocessing. + The frequency of the messages increases with the verbosity level. + If it more than ``10``, all iterations are reported. + Above ``50``, the output is sent to stdout. + + By default ``0``. + + random_state: Optional[Union[int, RandomState]] + Pseudo random number generator state used for random sampling. + Pass an int for reproducible output across multiple function calls. + + By default ``None``. + + Attributes + ---------- + single_estimator_: sklearn.RegressorMixin + Estimator fitted on the whole training set. + + estimators_: list + List of out-of-folds estimators. + + k_: ArrayLike + - Array of nans, of shape (len(y), 1) if ``cv`` is ``"prefit"`` + (defined but not used) + - Dummy array of folds containing each training sample, otherwise. + Of shape (n_samples_train, cv.get_n_splits(X_train, y_train)). + """ + no_agg_cv_ = ["prefit", "split"] + no_agg_methods_ = ["naive", "base"] + fit_attributes = [ + "single_estimator_", + "estimators_", + "k_", + ] + + @staticmethod + def _fit_oof_estimator( + estimator: RegressorMixin, + X: ArrayLike, + y: ArrayLike, + train_index: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + ) -> RegressorMixin: + """ + Fit a single out-of-fold model on a given training set. + + Parameters + ---------- + estimator: RegressorMixin + Estimator to train. + + X: ArrayLike of shape (n_samples, n_features) + Input data. + + y: ArrayLike of shape (n_samples,) + Input labels. + + train_index: ArrayLike of shape (n_samples_train) + Training data indices. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Sample weights. If None, then samples are equally weighted. + By default ``None``. + + Returns + ------- + Tuple[RegressorMixin, NDArray, ArrayLike] + + - [0]: RegressorMixin, fitted estimator + - [1]: NDArray of shape (n_samples_val,), + estimator predictions on the validation fold. + - [2]: ArrayLike of shape (n_samples_val,), + validation data indices. + """ + + @staticmethod + def _predict_oof_estimator( + estimator: RegressorMixin, + X: ArrayLike, + val_index: ArrayLike, + ): + """ + Perform predictions on a single out-of-fold model on a validation set. + + Parameters + ---------- + estimator: RegressorMixin + Estimator to train. + + X: ArrayLike of shape (n_samples, n_features) + Input data. + + val_index: ArrayLike of shape (n_samples_val) + Validation data indices. + + Returns + ------- + Tuple[NDArray, ArrayLike] + Predictions of estimator from val_index of X. + """ + + def _aggregate_with_mask( + self, + x: NDArray, + k: NDArray + ) -> NDArray: + """ + Take the array of predictions, made by the refitted estimators, + on the testing set, and the 1-or-nan array indicating for each training + sample which one to integrate, and aggregate to produce phi-{t}(x_t) + for each training sample x_t. + + Parameters + ---------- + x: ArrayLike of shape (n_samples_test, n_estimators) + Array of predictions, made by the refitted estimators, + for each sample of the testing set. + + k: ArrayLike of shape (n_samples_training, n_estimators) + 1-or-nan array: indicates whether to integrate the prediction + of a given estimator into the aggregation, for each training + sample. + + Returns + ------- + ArrayLike of shape (n_samples_test,) + Array of aggregated predictions for each testing sample. + """ + + def _pred_multi(self, X: ArrayLike) -> NDArray: + """ + Return a prediction per train sample for each test sample, by + aggregation with matrix ``k_``. + + Parameters + ---------- + X: ArrayLike of shape (n_samples_test, n_features) + Input data + + Returns + ------- + NDArray of shape (n_samples_test, n_samples_train) + """ + + def predict_calib(self, X: ArrayLike) -> NDArray: + """ + Perform predictions on X : the calibration set. This method is + called in the ConformityScore class to compute the conformity scores. + + Parameters + ---------- + X: ArrayLike of shape (n_samples_test, n_features) + Input data + + Returns + ------- + NDArray of shape (n_samples_test, 1) + The predictions. + """ + + def fit( + self, + X: ArrayLike, + y: ArrayLike, + sample_weight: Optional[ArrayLike] = None, + ) -> EnsembleEstimator: + """ + Fit the base estimator under the ``single_estimator_`` attribute. + Fit all cross-validated estimator clones + and rearrange them into a list, the ``estimators_`` attribute. + Out-of-fold conformity scores are stored under + the ``conformity_scores_`` attribute. + + Parameters + ---------- + X: ArrayLike of shape (n_samples, n_features) + Input data. + + y: ArrayLike of shape (n_samples,) + Input labels. + + sample_weight: Optional[ArrayLike] of shape (n_samples,) + Sample weights. If None, then samples are equally weighted. + By default ``None``. + + Returns + ------- + EnsembleRegressor + The estimator fitted. + """ + + def predict( + self, + X: ArrayLike, + ensemble: bool = False, + return_multi_pred: bool = True + ) -> Union[NDArray, Tuple[NDArray, NDArray, NDArray]]: + """ + Predict target from X. It also computes the prediction per train sample + for each test sample according to ``self.method``. + + Parameters + ---------- + X: ArrayLike of shape (n_samples, n_features) + Test data. + + ensemble: bool + Boolean determining whether the predictions are ensembled or not. + If ``False``, predictions are those of the model trained on the + whole training set. + If ``True``, predictions from perturbed models are aggregated by + the aggregation function specified in the ``agg_function`` + attribute. + + If ``cv`` is ``"prefit"`` or ``"split"``, ``ensemble`` is ignored. + + By default ``False``. + + return_multi_pred: bool + + + Returns + ------- + Tuple[NDArray, NDArray, NDArray] + - Predictions + - The multiple predictions for the lower bound of the intervals. + - The multiple predictions for the upper bound of the intervals. + """ diff --git a/mapie/regression/regression.py b/mapie/regression/regression.py index 78141a75..e51fbecc 100644 --- a/mapie/regression/regression.py +++ b/mapie/regression/regression.py @@ -13,7 +13,7 @@ from mapie._typing import ArrayLike, NDArray from mapie.conformity_scores import ConformityScore -from .estimator import EnsembleRegressor +from mapie.estimator.estimator import EnsembleRegressor from mapie.utils import (check_alpha, check_alpha_and_n_samples, check_conformity_score, check_cv, check_estimator_fit_predict, check_n_features_in, diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index 5dd6f93a..e5a8353e 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -24,7 +24,7 @@ GammaConformityScore) from mapie.metrics import regression_coverage_score from mapie.regression import MapieRegressor -from mapie.regression.estimator import EnsembleRegressor +from mapie.estimator.estimator import EnsembleRegressor from mapie.subsample import Subsample X_toy = np.array([0, 1, 2, 3, 4, 5]).reshape(-1, 1) From 5cc17fe9a841e21db9d98bed2de33c9fd2934a57 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Mon, 10 Jul 2023 10:29:42 +0200 Subject: [PATCH 23/30] UPD: dosctrings + beautify --- mapie/conformity_scores/conformity_scores.py | 54 +++++++++---------- .../residual_conformity_scores.py | 8 +-- mapie/tests/test_conformity_scores.py | 37 +++++++------ 3 files changed, 53 insertions(+), 46 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index d667d2a0..f658dd40 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -69,18 +69,18 @@ def get_signed_conformity_scores( Parameters ---------- - X: ArrayLike of shape (n_samples_calib, n_features) + X: ArrayLike of shape (n_samples, n_features) Observed feature values. - y: ArrayLike of shape (n_samples_calib,) + y: ArrayLike of shape (n_samples,) Observed target values. - y_pred: ArrayLike of shape (n_samples_calib,) + y_pred: ArrayLike of shape (n_samples,) Predicted target values. Returns ------- - NDArray of shape (n_samples_calib,) + NDArray of shape (n_samples,) Signed conformity scores. """ @@ -101,24 +101,24 @@ def get_estimation_distribution( Parameters ---------- - X: ArrayLike of shape (n_samples_calib, n_features) + X: ArrayLike of shape (n_samples, n_features) Observed feature values. y_pred: ArrayLike - The shape is either (n_samples_calib, n_samples_train): when the + The shape is either (n_samples, n_references): when the method is called in ``get_bounds`` it needs a prediction per train sample for each test sample to compute the bounds. - Or (n_samples_calib, 1): when it is called in ``check_consistency`` + Or (n_samples,): when it is called in ``check_consistency`` conformity_scores: ArrayLike - The shape is either (n_samples_calib, n_alpha) when it is the - conformity scores themselves or (n_alpha, 1) when it is only the + The shape is either (n_samples, 1) when it is the + conformity scores themselves or (1, n_alpha) when it is only the quantile of the conformity scores. Returns ------- - NDArray of shape (n_samples_calib, n_alpha) or - (n_samples_calib, n_samples_train) according to the shape of ``y_pred`` + NDArray of shape (n_samples, n_alpha) or + (n_samples, n_references) according to the shape of ``y_pred`` Observed values. """ @@ -142,16 +142,16 @@ def check_consistency( Parameters ---------- - X: ArrayLike of shape (n_samples_calib, n_features) + X: ArrayLike of shape (n_samples, n_features) Observed feature values. - y: ArrayLike of shape (n_samples_calib,) + y: ArrayLike of shape (n_samples,) Observed target values. - y_pred: ArrayLike of shape (n_samples_calib,) + y_pred: ArrayLike of shape (n_samples,) Predicted target values. - conformity_scores: ArrayLike of shape (n_samples_calib,) + conformity_scores: ArrayLike of shape (n_samples,) Conformity scores. Raises @@ -188,18 +188,18 @@ def get_conformity_scores( Parameters ---------- - X: NDArray of shape (n_samples_calib, n_features) + X: NDArray of shape (n_samples, n_features) Observed feature values. - y: NDArray of shape (n_samples_calib,) + y: NDArray of shape (n_samples,) Observed target values. - y_pred: NDArray of shape (n_samples_calib,) + y_pred: NDArray of shape (n_samples,) Predicted target values. Returns ------- - NDArray of shape (n_samples_calib, 1) + NDArray of shape (n_samples,) Conformity scores. """ conformity_scores = self.get_signed_conformity_scores(X, y, y_pred) @@ -222,8 +222,8 @@ def get_quantile( Parameters ---------- - values: NDArray of shape (n_samples_calib, n_alpha) or - (n_samples_calib, n_samples_train) + values: NDArray of shape (n_samples,) or + (n_samples, n_references) Values from which the quantile is computed, it can be the conformity scores or the conformity scores aggregated with the predictions. @@ -240,7 +240,7 @@ def get_quantile( Returns ------- - NDArray of shape (n_alpha,) + NDArray of shape (1, n_alpha) or (n_samples, n_alpha) The quantile of the conformity scores. """ quantile = np.column_stack([ @@ -269,13 +269,13 @@ def get_bounds( Parameters ---------- - X: ArrayLike of shape (n_samples_test, n_features) + X: ArrayLike of shape (n_samples, n_features) Observed feature values. estimator: EnsembleEstimator Estimator that is fitted to predict y from X. - conformity_scores: ArrayLike of shape (n_samples_calib,) + conformity_scores: ArrayLike of shape (n_samples,) Conformity scores. alpha_np: NDArray of shape (n_alpha,) @@ -295,11 +295,11 @@ def get_bounds( Returns ------- Tuple[NDArray, NDArray, NDArray] - - The predictions itself. (y_pred) of shape (n_samples_test,). + - The predictions itself. (y_pred) of shape (n_samples,). - The lower bounds of the prediction intervals of shape - (n_samples_test,). + (n_samples, n_alpha). - The upper bounds of the prediction intervals of shape - (n_samples_test,). + (n_samples, n_alpha). """ y_pred, y_pred_low, y_pred_up = estimator.predict(X, ensemble) signed = -1 if self.sym else 1 diff --git a/mapie/conformity_scores/residual_conformity_scores.py b/mapie/conformity_scores/residual_conformity_scores.py index 92772f32..8dbff9fd 100644 --- a/mapie/conformity_scores/residual_conformity_scores.py +++ b/mapie/conformity_scores/residual_conformity_scores.py @@ -29,8 +29,8 @@ def get_signed_conformity_scores( y_pred: ArrayLike, ) -> NDArray: """ - Compute the signed conformity scores from the observed values - and the predicted ones, from the following formula: + Compute the signed conformity scores from the predicted values + and the observed ones, from the following formula: signed conformity score = y - y_pred """ return np.subtract(y, y_pred) @@ -43,7 +43,7 @@ def get_estimation_distribution( ) -> NDArray: """ Compute samples of the estimation distribution from the predicted - targets and ``conformity_scores``, from the following formula: + values and the conformity scores, from the following formula: signed conformity score = y - y_pred <=> y = y_pred + signed conformity score @@ -123,7 +123,7 @@ def get_estimation_distribution( ) -> NDArray: """ Compute samples of the estimation distribution from the predicted - targets and ``conformity_scores``, from the following formula: + values and the conformity scores, from the following formula: signed conformity score = (y - y_pred) / y_pred <=> y = y_pred * (1 + signed conformity score) diff --git a/mapie/tests/test_conformity_scores.py b/mapie/tests/test_conformity_scores.py index 4fb7f635..a47cfbf4 100644 --- a/mapie/tests/test_conformity_scores.py +++ b/mapie/tests/test_conformity_scores.py @@ -22,8 +22,9 @@ def get_signed_conformity_scores( ) -> NDArray: return np.subtract(y, y_pred) - def get_estimation_distribution(self, X: ArrayLike, y_pred: ArrayLike, - conformity_scores: ArrayLike) -> NDArray: + def get_estimation_distribution( + self, X: ArrayLike, y_pred: ArrayLike, conformity_scores: ArrayLike + ) -> NDArray: """ A positive constant is added to the sum between predictions and conformity scores to make the estimated distribution inconsistent @@ -65,8 +66,9 @@ def test_absolute_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for AbsoluteConformityScore.""" # noqa: E501 abs_conf_score = AbsoluteConformityScore() - y_obs = abs_conf_score.get_estimation_distribution(X_toy, y_pred, - conf_scores) + y_obs = abs_conf_score.get_estimation_distribution( + X_toy, y_pred, conf_scores + ) np.testing.assert_allclose(y_obs, y_toy) @@ -77,8 +79,9 @@ def test_absolute_conformity_score_consistency(y_pred: NDArray) -> None: signed_conf_scores = abs_conf_score.get_signed_conformity_scores( X_toy, y_toy, y_pred ) - y_obs = abs_conf_score.get_estimation_distribution(X_toy, y_pred, - signed_conf_scores) + y_obs = abs_conf_score.get_estimation_distribution( + X_toy, y_pred, signed_conf_scores + ) np.testing.assert_allclose(y_obs, y_toy) @@ -89,7 +92,8 @@ def test_gamma_conformity_score_get_conformity_scores( """Test conformity score computation for GammaConformityScore.""" gamma_conf_score = GammaConformityScore() conf_scores = gamma_conf_score.get_conformity_scores( - X_toy, y_toy, y_pred) + X_toy, y_toy, y_pred + ) expected_signed_conf_scores = np.array(conf_scores_gamma_list) np.testing.assert_allclose(conf_scores, expected_signed_conf_scores) @@ -107,8 +111,9 @@ def test_gamma_conformity_score_get_estimation_distribution( ) -> None: """Test conformity observed value computation for GammaConformityScore.""" # noqa: E501 gamma_conf_score = GammaConformityScore() - y_obs = gamma_conf_score.get_estimation_distribution(X_toy, y_pred, - conf_scores) + y_obs = gamma_conf_score.get_estimation_distribution( + X_toy, y_pred, conf_scores + ) np.testing.assert_allclose(y_obs, y_toy) @@ -119,8 +124,9 @@ def test_gamma_conformity_score_consistency(y_pred: NDArray) -> None: signed_conf_scores = gamma_conf_score.get_signed_conformity_scores( X_toy, y_toy, y_pred ) - y_obs = gamma_conf_score.get_estimation_distribution(X_toy, y_pred, - signed_conf_scores) + y_obs = gamma_conf_score.get_estimation_distribution( + X_toy, y_pred, signed_conf_scores + ) np.testing.assert_allclose(y_obs, y_toy) @@ -183,14 +189,15 @@ def test_gamma_conformity_score_check_predicted_value( ValueError, match=r".*At least one of the predicted target is negative.*" ): - gamma_conf_score.get_estimation_distribution(X_toy, y_pred, - conf_scores) + gamma_conf_score.get_estimation_distribution( + X_toy, y_pred, conf_scores + ) def test_check_consistency() -> None: """ - Test that a dummy ConformityScore class that gives inconsistent - conformity scores and distributions raises an error. + Test that a dummy ConformityScore class that gives inconsistent scores + and distributions raises an error. """ dummy_conf_score = DummyConformityScore() conformity_scores = dummy_conf_score.get_signed_conformity_scores( From 1eadfa56283557eb6106232ad7eb5e84e6b5600f Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Mon, 10 Jul 2023 10:41:38 +0200 Subject: [PATCH 24/30] FIX: estimator interface --- mapie/estimator/estimator.py | 11 +- mapie/estimator/interface.py | 256 ----------------------------------- 2 files changed, 3 insertions(+), 264 deletions(-) diff --git a/mapie/estimator/estimator.py b/mapie/estimator/estimator.py index 167bd2a6..ed05b850 100644 --- a/mapie/estimator/estimator.py +++ b/mapie/estimator/estimator.py @@ -205,13 +205,8 @@ def _fit_oof_estimator( Returns ------- - Tuple[RegressorMixin, NDArray, ArrayLike] - - - [0]: RegressorMixin, fitted estimator - - [1]: NDArray of shape (n_samples_val,), - estimator predictions on the validation fold. - - [2]: ArrayLike of shape (n_samples_val,), - validation data indices. + RegressorMixin + Fitted estimator. """ X_train = _safe_indexing(X, train_index) y_train = _safe_indexing(y, train_index) @@ -229,7 +224,7 @@ def _predict_oof_estimator( estimator: RegressorMixin, X: ArrayLike, val_index: ArrayLike, - ): + ) -> Tuple[NDArray, ArrayLike]: """ Perform predictions on a single out-of-fold model on a validation set. diff --git a/mapie/estimator/interface.py b/mapie/estimator/interface.py index 25435fe6..dafbb382 100644 --- a/mapie/estimator/interface.py +++ b/mapie/estimator/interface.py @@ -12,263 +12,7 @@ class EnsembleEstimator(RegressorMixin): This class implements methods to handle the training and usage of the estimator. This estimator can be unique or composed by cross validated estimators. - - Parameters - ---------- - estimator: Optional[RegressorMixin] - Any regressor with scikit-learn API - (i.e. with ``fit`` and ``predict`` methods). - If ``None``, estimator defaults to a ``LinearRegression`` instance. - - By default ``None``. - - method: str - Method to choose for prediction interval estimates. - Choose among: - - - ``"naive"``, based on training set conformity scores, - - ``"base"``, based on validation sets conformity scores, - - ``"plus"``, based on validation conformity scores and - testing predictions, - - ``"minmax"``, based on validation conformity scores and - testing predictions (min/max among cross-validation clones). - - By default ``"plus"``. - - cv: Optional[Union[int, str, BaseCrossValidator]] - The cross-validation strategy for computing conformity scores. - It directly drives the distinction between jackknife and cv variants. - Choose among: - - - ``None``, to use the default 5-fold cross-validation - - integer, to specify the number of folds. - If equal to ``-1``, equivalent to - ``sklearn.model_selection.LeaveOneOut()``. - - CV splitter: any ``sklearn.model_selection.BaseCrossValidator`` - Main variants are: - - ``sklearn.model_selection.LeaveOneOut`` (jackknife), - - ``sklearn.model_selection.KFold`` (cross-validation), - - ``subsample.Subsample`` object (bootstrap). - - ``"split"``, does not involve cross-validation but a division - of the data into training and calibration subsets. The splitter - used is the following: ``sklearn.model_selection.ShuffleSplit``. - - ``"prefit"``, assumes that ``estimator`` has been fitted already, - and the ``method`` parameter is ignored. - All data provided in the ``fit`` method is then used - for computing conformity scores only. - At prediction time, quantiles of these conformity scores are used - to provide a prediction interval with fixed width. - The user has to take care manually that data for model fitting and - conformity scores estimate are disjoint. - - By default ``None``. - - test_size: Optional[Union[int, float]] - If ``float``, should be between ``0.0`` and ``1.0`` and represent the - proportion of the dataset to include in the test split. If ``int``, - represents the absolute number of test samples. If ``None``, - it will be set to ``0.1``. - - If cv is not ``"split"``, ``test_size`` is ignored. - - By default ``None``. - - n_jobs: Optional[int] - Number of jobs for parallel processing using joblib - via the "locky" backend. - If ``-1`` all CPUs are used. - If ``1`` is given, no parallel computing code is used at all, - which is useful for debugging. - For ``n_jobs`` below ``-1``, ``(n_cpus + 1 - n_jobs)`` are used. - ``None`` is a marker for `unset` that will be interpreted as - ``n_jobs=1`` (sequential execution). - - By default ``None``. - - agg_function: Optional[str] - Determines how to aggregate predictions from perturbed models, both at - training and prediction time. - - If ``None``, it is ignored except if ``cv`` class is ``Subsample``, - in which case an error is raised. - If ``"mean"`` or ``"median"``, returns the mean or median of the - predictions computed from the out-of-folds models. - Note: if you plan to set the ``ensemble`` argument to ``True`` in the - ``predict`` method, you have to specify an aggregation function. - Otherwise an error would be raised. - - The Jackknife+ interval can be interpreted as an interval around the - median prediction, and is guaranteed to lie inside the interval, - unlike the single estimator predictions. - - When the cross-validation strategy is ``Subsample`` (i.e. for the - Jackknife+-after-Bootstrap method), this function is also used to - aggregate the training set in-sample predictions. - - If ``cv`` is ``"prefit"`` or ``"split"``, ``agg_function`` is ignored. - - By default ``"mean"``. - - verbose: int - The verbosity level, used with joblib for multiprocessing. - The frequency of the messages increases with the verbosity level. - If it more than ``10``, all iterations are reported. - Above ``50``, the output is sent to stdout. - - By default ``0``. - - random_state: Optional[Union[int, RandomState]] - Pseudo random number generator state used for random sampling. - Pass an int for reproducible output across multiple function calls. - - By default ``None``. - - Attributes - ---------- - single_estimator_: sklearn.RegressorMixin - Estimator fitted on the whole training set. - - estimators_: list - List of out-of-folds estimators. - - k_: ArrayLike - - Array of nans, of shape (len(y), 1) if ``cv`` is ``"prefit"`` - (defined but not used) - - Dummy array of folds containing each training sample, otherwise. - Of shape (n_samples_train, cv.get_n_splits(X_train, y_train)). """ - no_agg_cv_ = ["prefit", "split"] - no_agg_methods_ = ["naive", "base"] - fit_attributes = [ - "single_estimator_", - "estimators_", - "k_", - ] - - @staticmethod - def _fit_oof_estimator( - estimator: RegressorMixin, - X: ArrayLike, - y: ArrayLike, - train_index: ArrayLike, - sample_weight: Optional[ArrayLike] = None, - ) -> RegressorMixin: - """ - Fit a single out-of-fold model on a given training set. - - Parameters - ---------- - estimator: RegressorMixin - Estimator to train. - - X: ArrayLike of shape (n_samples, n_features) - Input data. - - y: ArrayLike of shape (n_samples,) - Input labels. - - train_index: ArrayLike of shape (n_samples_train) - Training data indices. - - sample_weight: Optional[ArrayLike] of shape (n_samples,) - Sample weights. If None, then samples are equally weighted. - By default ``None``. - - Returns - ------- - Tuple[RegressorMixin, NDArray, ArrayLike] - - - [0]: RegressorMixin, fitted estimator - - [1]: NDArray of shape (n_samples_val,), - estimator predictions on the validation fold. - - [2]: ArrayLike of shape (n_samples_val,), - validation data indices. - """ - - @staticmethod - def _predict_oof_estimator( - estimator: RegressorMixin, - X: ArrayLike, - val_index: ArrayLike, - ): - """ - Perform predictions on a single out-of-fold model on a validation set. - - Parameters - ---------- - estimator: RegressorMixin - Estimator to train. - - X: ArrayLike of shape (n_samples, n_features) - Input data. - - val_index: ArrayLike of shape (n_samples_val) - Validation data indices. - - Returns - ------- - Tuple[NDArray, ArrayLike] - Predictions of estimator from val_index of X. - """ - - def _aggregate_with_mask( - self, - x: NDArray, - k: NDArray - ) -> NDArray: - """ - Take the array of predictions, made by the refitted estimators, - on the testing set, and the 1-or-nan array indicating for each training - sample which one to integrate, and aggregate to produce phi-{t}(x_t) - for each training sample x_t. - - Parameters - ---------- - x: ArrayLike of shape (n_samples_test, n_estimators) - Array of predictions, made by the refitted estimators, - for each sample of the testing set. - - k: ArrayLike of shape (n_samples_training, n_estimators) - 1-or-nan array: indicates whether to integrate the prediction - of a given estimator into the aggregation, for each training - sample. - - Returns - ------- - ArrayLike of shape (n_samples_test,) - Array of aggregated predictions for each testing sample. - """ - - def _pred_multi(self, X: ArrayLike) -> NDArray: - """ - Return a prediction per train sample for each test sample, by - aggregation with matrix ``k_``. - - Parameters - ---------- - X: ArrayLike of shape (n_samples_test, n_features) - Input data - - Returns - ------- - NDArray of shape (n_samples_test, n_samples_train) - """ - - def predict_calib(self, X: ArrayLike) -> NDArray: - """ - Perform predictions on X : the calibration set. This method is - called in the ConformityScore class to compute the conformity scores. - - Parameters - ---------- - X: ArrayLike of shape (n_samples_test, n_features) - Input data - - Returns - ------- - NDArray of shape (n_samples_test, 1) - The predictions. - """ def fit( self, From 5e052cd661462083e8b4d71b5f7d56a5004d12d2 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Mon, 10 Jul 2023 10:42:02 +0200 Subject: [PATCH 25/30] UPD: parameter's name --- mapie/conformity_scores/conformity_scores.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/mapie/conformity_scores/conformity_scores.py b/mapie/conformity_scores/conformity_scores.py index f658dd40..ef4a79ad 100644 --- a/mapie/conformity_scores/conformity_scores.py +++ b/mapie/conformity_scores/conformity_scores.py @@ -211,7 +211,7 @@ def get_conformity_scores( @staticmethod def get_quantile( - values: NDArray, + conformity_scores: NDArray, alpha_np: NDArray, axis: int, method: str @@ -222,7 +222,7 @@ def get_quantile( Parameters ---------- - values: NDArray of shape (n_samples,) or + conformity_scores: NDArray of shape (n_samples,) or (n_samples, n_references) Values from which the quantile is computed, it can be the conformity scores or the conformity scores aggregated with @@ -245,7 +245,7 @@ def get_quantile( """ quantile = np.column_stack([ np_nanquantile( - values.astype(float), + conformity_scores.astype(float), _alpha, axis=axis, method=method @@ -308,17 +308,17 @@ def get_bounds( alpha_low = alpha_np if self.sym else alpha_np / 2 alpha_up = 1 - alpha_np if self.sym else 1 - alpha_np / 2 - values_low = self.get_estimation_distribution( + conformity_scores_low = self.get_estimation_distribution( X, y_pred_low, signed * conformity_scores ) - values_up = self.get_estimation_distribution( + conformity_scores_up = self.get_estimation_distribution( X, y_pred_up, conformity_scores ) bound_low = self.get_quantile( - values_low, alpha_low, axis=1, method="lower" + conformity_scores_low, alpha_low, axis=1, method="lower" ) bound_up = self.get_quantile( - values_up, alpha_up, axis=1, method="higher" + conformity_scores_up, alpha_up, axis=1, method="higher" ) else: quantile_search = "higher" if self.sym else "lower" From 0662338fb500b2cb3df05eeb16712597ecd0a9c2 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Mon, 10 Jul 2023 12:16:02 +0200 Subject: [PATCH 26/30] UPD: history.rst --- HISTORY.rst | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/HISTORY.rst b/HISTORY.rst index e1e13fc3..620eebdb 100644 --- a/HISTORY.rst +++ b/HISTORY.rst @@ -4,7 +4,8 @@ History ##### (##########) ------------------ - +* Refactor MapieRegressor and ConformityScore to add the possibility to use X in ConformityScore. +* Separate the handling of the estimator from MapieRegressor into a new class called EnsembleEstimator. * Add grouped conditional coverage metrics named SSC for regression and classification * Add HSIC metric for regression * Migrate conformity scores classes into conformity_scores module From 3d1b774ffec0ddc2af5db71b5bf4ee231d691c94 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Tue, 11 Jul 2023 15:40:54 +0200 Subject: [PATCH 27/30] FIX: doc random state and prediction type --- .../plot_compare_conformity_scores.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/regression/1-quickstart/plot_compare_conformity_scores.py b/examples/regression/1-quickstart/plot_compare_conformity_scores.py index 7d345b63..e4b79c70 100644 --- a/examples/regression/1-quickstart/plot_compare_conformity_scores.py +++ b/examples/regression/1-quickstart/plot_compare_conformity_scores.py @@ -39,7 +39,7 @@ from mapie.metrics import regression_coverage_score from mapie.regression import MapieRegressor -np.random.seed(0) +random_state = 42 # Parameters features = [ @@ -50,7 +50,7 @@ "GarageArea", ] alpha = 0.05 -rf_kwargs = {"n_estimators": 10, "random_state": 0} +rf_kwargs = {"n_estimators": 10, "random_state": random_state} model = RandomForestRegressor(**rf_kwargs) ############################################################################## @@ -66,7 +66,7 @@ X, y = fetch_openml(name="house_prices", return_X_y=True) X_train, X_test, y_train, y_test = train_test_split( - X[features], y, test_size=0.2 + X[features], y, test_size=0.2, random_state=random_state ) ############################################################################## @@ -87,9 +87,11 @@ ############################################################################## # First, train model with # :class:`~mapie.conformity_scores.AbsoluteConformityScore`. -mapie = MapieRegressor(model) +mapie = MapieRegressor(model, random_state=random_state) mapie.fit(X_train, y_train) -y_pred_absconfscore, y_pis_absconfscore = mapie.predict(X_test, alpha=alpha) +y_pred_absconfscore, y_pis_absconfscore = mapie.predict( + X_test, alpha=alpha, ensemble=True +) coverage_absconfscore = regression_coverage_score( y_test, y_pis_absconfscore[:, 0, 0], y_pis_absconfscore[:, 1, 0] @@ -118,10 +120,12 @@ def get_yerr(y_pred, y_pis): ############################################################################## # Then, train the model with # :class:`~mapie.conformity_scores.GammaConformityScore`. -mapie = MapieRegressor(model, conformity_score=GammaConformityScore()) +mapie = MapieRegressor( + model, conformity_score=GammaConformityScore(), random_state=random_state +) mapie.fit(X_train, y_train) y_pred_gammaconfscore, y_pis_gammaconfscore = mapie.predict( - X_test, alpha=[alpha] + X_test, alpha=[alpha], ensemble=True ) coverage_gammaconfscore = regression_coverage_score( From bda371a698531afbe7bf9b42c7010ffd8cb51048 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Wed, 12 Jul 2023 09:53:07 +0200 Subject: [PATCH 28/30] UPD: add parameter's description in docstring --- mapie/estimator/estimator.py | 4 +++- mapie/estimator/interface.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mapie/estimator/estimator.py b/mapie/estimator/estimator.py index ed05b850..f96a3221 100644 --- a/mapie/estimator/estimator.py +++ b/mapie/estimator/estimator.py @@ -470,7 +470,9 @@ def predict( By default ``False``. return_multi_pred: bool - + If ``True`` the method returns the predictions and the multiple + predictions (3 arrays). If ``False`` the method return the + simple predictions only. Returns ------- diff --git a/mapie/estimator/interface.py b/mapie/estimator/interface.py index dafbb382..468f3318 100644 --- a/mapie/estimator/interface.py +++ b/mapie/estimator/interface.py @@ -73,7 +73,9 @@ def predict( By default ``False``. return_multi_pred: bool - + If ``True`` the method returns the predictions and the multiple + predictions (3 arrays). If ``False`` the method return the + simple predictions only. Returns ------- From b2678fbd665bc5d6871fa7bddde7ef78580527fb Mon Sep 17 00:00:00 2001 From: Candice Moyet <62012974+candicemyt@users.noreply.github.com> Date: Wed, 12 Jul 2023 10:05:10 +0200 Subject: [PATCH 29/30] UPD: add if condition to avoid useless computations Co-authored-by: Thibault Cordier <124613154+thibaultcordier@users.noreply.github.com> --- mapie/estimator/estimator.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/mapie/estimator/estimator.py b/mapie/estimator/estimator.py index f96a3221..3e892440 100644 --- a/mapie/estimator/estimator.py +++ b/mapie/estimator/estimator.py @@ -484,7 +484,9 @@ def predict( check_is_fitted(self, self.fit_attributes) y_pred = self.single_estimator_.predict(X) - + if not return_multi_pred and not ensemble: + return y_pred + if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_: y_pred_multi_low = y_pred[:, np.newaxis] y_pred_multi_up = y_pred[:, np.newaxis] From 45eee11a3952682d0a9e6b23d49628346ec6bf50 Mon Sep 17 00:00:00 2001 From: Candice Moyet Date: Wed, 12 Jul 2023 10:38:56 +0200 Subject: [PATCH 30/30] ENH: new tests for new if condition in estimator --- mapie/estimator/estimator.py | 2 +- mapie/tests/test_regression.py | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/mapie/estimator/estimator.py b/mapie/estimator/estimator.py index 3e892440..33bda7b3 100644 --- a/mapie/estimator/estimator.py +++ b/mapie/estimator/estimator.py @@ -486,7 +486,7 @@ def predict( y_pred = self.single_estimator_.predict(X) if not return_multi_pred and not ensemble: return y_pred - + if self.method in self.no_agg_methods_ or self.cv in self.no_agg_cv_: y_pred_multi_low = y_pred[:, np.newaxis] y_pred_multi_up = y_pred[:, np.newaxis] diff --git a/mapie/tests/test_regression.py b/mapie/tests/test_regression.py index e5a8353e..4a9d9975 100644 --- a/mapie/tests/test_regression.py +++ b/mapie/tests/test_regression.py @@ -588,3 +588,27 @@ def test_conformity_score( ) mapie_reg.fit(X, y + 1e3) mapie_reg.predict(X, alpha=0.05) + + +@pytest.mark.parametrize("ensemble", [True, False]) +def test_return_only_ypred(ensemble: bool) -> None: + """Test that if return_multi_pred is False it only returns y_pred.""" + mapie_reg = MapieRegressor() + mapie_reg.fit(X_toy, y_toy) + output = mapie_reg.estimator_.predict( + X_toy, ensemble=ensemble, return_multi_pred=False + ) + assert len(output) == len(X_toy) + + +@pytest.mark.parametrize("ensemble", [True, False]) +def test_return_multi_pred(ensemble: bool) -> None: + """ + Test that if return_multi_pred is True it returns y_pred and multi_pred. + """ + mapie_reg = MapieRegressor() + mapie_reg.fit(X_toy, y_toy) + output = mapie_reg.estimator_.predict( + X_toy, ensemble=ensemble, return_multi_pred=True + ) + assert len(output) == 3