diff --git a/benchmarks/two_moons_multicalibrate.py b/benchmarks/two_moons_multicalibrate.py new file mode 100644 index 00000000..baf5205c --- /dev/null +++ b/benchmarks/two_moons_multicalibrate.py @@ -0,0 +1,150 @@ +import flax.linen as nn +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from sklearn.datasets import make_moons + +from fortuna.conformal import Multicalibrator +from fortuna.data import ( + DataLoader, + InputsLoader, +) +from fortuna.metric.classification import accuracy +from fortuna.model.mlp import MLP +from fortuna.prob_model import ( + CalibConfig, + CalibMonitor, + FitConfig, + FitMonitor, + FitOptimizer, + MAPPosteriorApproximator, + ProbClassifier, +) + +train_data = make_moons(n_samples=5000, noise=0.07, random_state=0) +val_data = make_moons(n_samples=1000, noise=0.07, random_state=1) +test_data = make_moons(n_samples=1000, noise=0.07, random_state=2) + +train_data_loader = DataLoader.from_array_data( + train_data, batch_size=128, shuffle=True, prefetch=True +) +val_data_loader = DataLoader.from_array_data(val_data, batch_size=128, prefetch=True) +test_data_loader = DataLoader.from_array_data(test_data, batch_size=128, prefetch=True) + +output_dim = 2 +prob_model = ProbClassifier( + model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)), + posterior_approximator=MAPPosteriorApproximator(), +) + +status = prob_model.train( + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + calib_data_loader=val_data_loader, + fit_config=FitConfig( + monitor=FitMonitor(metrics=(accuracy,), early_stopping_patience=10), + optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=10), + ), + calib_config=CalibConfig(monitor=CalibMonitor(early_stopping_patience=2)), +) + +test_inputs_loader = test_data_loader.to_inputs_loader() +test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader) +test_modes = prob_model.predictive.mode( + inputs_loader=test_inputs_loader, means=test_means +) + +fig = plt.figure(figsize=(6, 3)) +size = 150 +xx = np.linspace(-4, 4, size) +yy = np.linspace(-4, 4, size) +grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy]) +grid_loader = InputsLoader.from_array_inputs(grid) +grid_entropies = prob_model.predictive.entropy(grid_loader).reshape(size, size) +grid = grid.reshape(size, size, 2) +plt.title("Predictions and entropy", fontsize=12) +im = plt.pcolor(grid[:, :, 0], grid[:, :, 1], grid_entropies) +plt.scatter( + test_data[0][:, 0], + test_data[0][:, 1], + s=1, + c=["C0" if i == 1 else "C1" for i in test_modes], +) +plt.colorbar() +plt.show() + +val_inputs_loader = val_data_loader.to_inputs_loader() +test_inputs_loader = test_data_loader.to_inputs_loader() +val_targets = val_data_loader.to_array_targets() +test_targets = test_data_loader.to_array_targets() + +val_means = prob_model.predictive.mean(val_inputs_loader) +test_means = prob_model.predictive.mean(val_inputs_loader) + +mc = Multicalibrator() +scores = val_targets +test_scores = test_targets +groups = jnp.stack((val_means.argmax(1) == 0, val_means.argmax(1) == 1), axis=1) +test_groups = jnp.stack((test_means.argmax(1) == 0, test_means.argmax(1) == 1), axis=1) +values = val_means[:, 1] +test_values = test_means[:, 1] +calib_test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + test_values=test_values, + n_buckets=1000, +) + +plt.figure(figsize=(10, 3)) +plt.suptitle("Multivalid calibration of probability that Y=1") +plt.subplot(1, 3, 1) 
+plt.title("all test inputs") +plt.hist([test_values, calib_test_values])[-1] +plt.legend(["before calibration", "after calibration"]) +plt.xlabel("prob") +plt.subplot(1, 3, 2) +plt.title("inputs for which we predict 0") +plt.hist([test_values[test_groups[:, 0]], calib_test_values[test_groups[:, 0]]])[-1] +plt.xlabel("prob") +plt.subplot(1, 3, 3) +plt.title("inputs for which we predict 1") +plt.hist([test_values[test_groups[:, 1]], calib_test_values[test_groups[:, 1]]])[-1] +plt.xlabel("prob") +plt.tight_layout() +plt.show() + +plt.title("Max calibration error decay during calibration") +plt.semilogy(status["max_calib_errors"]) +plt.show() + +print( + "Per-group reweighted avg. squared calib. error before calibration: ", + mc.calibration_error( + scores=test_scores, groups=test_groups, values=test_means.max(1) + ), +) +print( + "Per-group reweighted avg. squared calib. error after calibration: ", + mc.calibration_error( + scores=test_scores, groups=test_groups, values=calib_test_values + ), +) + +print( + "Mismatch between labels and probs before calibration: ", + jnp.mean( + jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values)) + ), +) +print( + "Mismatch between labels and probs after calibration: ", + jnp.mean( + jnp.maximum( + (1 - test_targets) * calib_test_values, + test_targets * (1 - calib_test_values), + ) + ), +) diff --git a/docs/source/methods.rst b/docs/source/methods.rst index a576a10a..560eaeae 100644 --- a/docs/source/methods.rst +++ b/docs/source/methods.rst @@ -62,9 +62,14 @@ For classification: sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time. - **BatchMVP** `[Jung C. et al., 2022] `_ - a conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and + A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and non-conformity thresholds. +- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] `_, `[Roth A., Algorithm 15] `_ + Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point. + For example, in binary classification, it can return calibrated probabilities of predictions. + This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds. + For regression: - **Conformalized quantile regression** `[Romano et al., 2019] `_ @@ -79,12 +84,16 @@ For regression: satisfying minimal coverage properties. - **BatchMVP** `[Jung C. et al., 2022] `_ - a conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and + A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and non-conformity thresholds. - **EnbPI** `[Xu et al., 2021] `_ A conformal prediction method for time series regression based on data bootstrapping. +- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] `_, `[Roth A., Algorithm 15] `_ + Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point. + This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds. + - **Adaptive conformal inference** `[Gibbs et al., 2021] `_ A method for conformal prediction that aims at correcting the coverage of conformal prediction methods in a sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time. 
diff --git a/docs/source/references/conformal.rst b/docs/source/references/conformal.rst index efa7f0d6..131f90b2 100644 --- a/docs/source/references/conformal.rst +++ b/docs/source/references/conformal.rst @@ -16,6 +16,8 @@ and :ref:`regression `. .. automodule:: fortuna.conformal.classification.batch_mvp +.. automodule:: fortuna.conformal.classification.multicalibrator + .. _conformal_regression: .. automodule:: fortuna.conformal.regression.quantile @@ -33,3 +35,5 @@ and :ref:`regression `. .. automodule:: fortuna.conformal.regression.adaptive_conformal_regressor .. automodule:: fortuna.conformal.regression.batch_mvp + +.. automodule:: fortuna.conformal.regression.multicalibrator diff --git a/examples/multivalid_coverage.pct.py b/examples/multivalid_coverage.pct.py index f98914ab..241d0b8e 100644 --- a/examples/multivalid_coverage.pct.py +++ b/examples/multivalid_coverage.pct.py @@ -206,11 +206,9 @@ def plot_intervals(xx, means, intervals, test_data, method): # %% [markdown] # We finally introduce Batch MVP [[Jung C. et al., 2022]](https://arxiv.org/pdf/2209.15145.pdf) and show that it improves group-conditional coverage. For its usage, we require: # -# - a valid non-conformity score function. This can be any score function measuring the degree of non-conformity between inputs $x$ and targets $y$. The less $x$ and $y$ conform with each other, the larger the score should be. A simple example of score function in regression is $s(x,y)=|y - h(x)|$, where $h$ is an arbitrary model. For the purpose of this example, we use the same score function as in CQR, that is $s(x,y)=\max\left(q_{\frac{\alpha}{2}} - y, y - q_{1 - \frac{\alpha}{2}}\right)$, where $\alpha$ is the desired coverage error, i.e. $\alpha=0.05$, and $q_\alpha$ is a corresponding quantile. +# - non-conformity scores evaluated on calibration. These can be evaluations of any score function measuring the degree of non-conformity between inputs $x$ and targets $y$. The less $x$ and $y$ conform with each other, the larger the score should be. A simple example of score function in regression is $s(x,y)=|y - h(x)|$, where $h$ is an arbitrary model. For the purpose of this example, we use the same score function as in CQR, that is $s(x,y)=\max\left(q_{\frac{\alpha}{2}} - y, y - q_{1 - \frac{\alpha}{2}}\right)$, where $\alpha$ is the desired coverage error, i.e. $\alpha=0.05$, and $q_\alpha$ is a corresponding quantile. # -# - the group functions. These construct sub-domains of interest of the input domain. As we defined above, here we use $g_1(x) = \mathbb{1}[x < 0]$ and $g_2(x) = \mathbb{1}[x \ge 0]$. -# -# - the bounds function. This is a function $b(x, \tau)$ that simultaneously defines the lower and upper bounds of the conformal interval given an input $x$ and a threshold $\tau$. For example, for the score function in use, we have $b(x, \tau) = [q_{\frac{\alpha}{2}} - \tau, q_{1 - \frac{\alpha}{2}} + \tau]$. Please notice that the bounds function is related to the inverse score function with respect to $y$. In fact, it defines the two extreme values of $y$ that satisfy the relation $s(x, y) \le \tau$. +# - group evaluations on calibration and test data. These construct sub-domains of interest of the input domain. As we defined above, here we use $g_1(x) = \mathbb{1}[x < 0]$ and $g_2(x) = \mathbb{1}[x \ge 0]$. # # That's it! Defined these, we are ready to run `BatchMVP`. 
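One practical detail before the next cell: `calibrate` expects the non-conformity scores to lie between 0 and 1, so the example min-max rescales them and later maps the returned thresholds back to the original score scale. A self-contained sketch of that pattern (the numbers are illustrative only):

import jax.numpy as jnp

# `calibrate` requires scores in [0, 1]; min-max rescaling, undone on the returned
# thresholds, is the simple approach used in the next cell.
raw_scores = jnp.array([0.3, 2.1, 5.7, 1.0])
min_s, max_s = raw_scores.min(), raw_scores.max()
unit_scores = (raw_scores - min_s) / (max_s - min_s)      # pass these to `calibrate`
unit_thresholds = jnp.array([0.2, 0.2, 0.8, 0.5])         # e.g. thresholds returned by `calibrate`
thresholds = min_s + (max_s - min_s) * unit_thresholds    # back to the original score scale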
@@ -218,36 +216,42 @@ def plot_intervals(xx, means, intervals, test_data, method): from fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor import jax.numpy as jnp - -def score_fn(x, y): - qleft, qright = prob_model.predictive.quantile( - [0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(x) - ) - return jnp.maximum(qleft - y, y - qright).squeeze(1) - - -def bounds_fn(x, t): - qleft, qright = prob_model.predictive.quantile( - [0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(x) - ) - return qleft.squeeze(1) - t, qright.squeeze(1) + t - - -batchmvp = BatchMVPConformalRegressor( - score_fn=score_fn, group_fns=group_fns, bounds_fn=bounds_fn +qleft, qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader ) -test_batchmvp_intervals, max_calib_errors = batchmvp.conformal_interval( - calib_data_loader, test_inputs_loader, return_max_calib_error=True +scores = jnp.maximum(qleft - calib_targets, calib_targets - qright).squeeze(1) +min_score, max_score = scores.min(), scores.max() +scores = (scores - min_score) / (max_score - min_score) +groups = jnp.stack([g(calib_data[0]) for g in group_fns], axis=1) +test_groups = jnp.stack([g(test_data[0]) for g in group_fns], axis=1) + +batchmvp = BatchMVPConformalRegressor() +test_thresholds, status = batchmvp.calibrate( + scores=scores, groups=groups, test_groups=test_groups ) +test_thresholds = min_score + (max_score - min_score) * test_thresholds # %% [markdown] # At each iteration, `BatchMVP` computes the maximum calibration error over the different groups. We report its decay in the figure below. # %% plt.figure(figsize=(6, 3)) -plt.plot(max_calib_errors, label="maximum calibration error decay") +plt.plot(status["max_calib_errors"], label="maximum calibration error decay") plt.xlabel("rounds") plt.legend() +plt.show() + +# %% [markdown] +# Given the test thresholds, we can find the lower and upper bounds of the conformal intervals by inverting the score function $s(x, y)$ with respect to $y$. This gives $b(x, \tau) = [q_{\frac{\alpha}{2}}(x) - \tau, q_{1 - \frac{\alpha}{2}}(x) + \tau]$, where $\tau$ denotes the thresholds. + +# %% +test_qleft, test_qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], test_inputs_loader +) +test_qleft, test_qright = test_qleft.squeeze(1), test_qright.squeeze(1) +test_batchmvp_intervals = jnp.stack( + (test_qleft - test_thresholds, test_qright + test_thresholds), axis=1 +) # %% [markdown] # We now compute coverage metrics. As expected, `BatchMVP` not only provides good marginal coverage overall, but also improves coverage on both negative and positive inputs. @@ -268,5 +272,14 @@ def bounds_fn(x, t): # Once again, we visualize predictions and estimated intervals.
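For inputs that were never part of the calibration data, such as the plotting grid used below, only their group evaluations are needed: `apply_patches` re-applies the patches found by `calibrate`, and the result is mapped back to the original score scale. A sketch of the generic pattern (`new_inputs` is a placeholder for any batch of inputs; the other names mirror this example):

# Thresholds for new inputs require only their group evaluations; no further calibration data is needed.
new_groups = jnp.stack([g(new_inputs) for g in group_fns], axis=1)
new_thresholds = batchmvp.apply_patches(groups=new_groups)
new_thresholds = min_score + (max_score - min_score) * new_thresholds  # undo the score rescaling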
# %% -xx_batchmvp_intervals = batchmvp.conformal_interval(calib_data_loader, xx_loader) +xx_qleft, xx_qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx) +) +xx_qleft, xx_qright = xx_qleft.squeeze(1), xx_qright.squeeze(1) +xx_groups = jnp.stack([g(xx) for g in group_fns], axis=1) +xx_thresholds = batchmvp.apply_patches(groups=xx_groups) +xx_thresholds = min_score + (max_score - min_score) * xx_thresholds +xx_batchmvp_intervals = jnp.stack( + (xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1 +) plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP") diff --git a/fortuna/conformal/__init__.py b/fortuna/conformal/__init__.py index 5ff40f8e..3118c74b 100644 --- a/fortuna/conformal/__init__.py +++ b/fortuna/conformal/__init__.py @@ -9,6 +9,7 @@ from fortuna.conformal.classification.simple_prediction import ( SimplePredictionConformalClassifier, ) +from fortuna.conformal.multivalid.multicalibrator import Multicalibrator from fortuna.conformal.regression.adaptive_conformal_regressor import ( AdaptiveConformalRegressor, ) diff --git a/fortuna/conformal/batch_mvp.py b/fortuna/conformal/batch_mvp.py deleted file mode 100644 index 1345c5de..00000000 --- a/fortuna/conformal/batch_mvp.py +++ /dev/null @@ -1,209 +0,0 @@ -from __future__ import annotations - -import abc -import logging -from typing import ( - Callable, - List, - Tuple, - Union, -) - -from jax import vmap -import jax.numpy as jnp - -from fortuna.data.loader import ( - DataLoader, - InputsLoader, -) -from fortuna.typing import Array - - -class Group: - def __init__(self, group_fn: Callable[[Array], Array]): - self.group_fn = group_fn - - def __call__(self, x): - g = self.group_fn(x) - if g.ndim > 1: - raise ValueError( - "Evaluations of the group function `group_fn` must be one-dimensional arrays." - ) - if jnp.any((g != 0) * (g != 1)): - raise ValueError( - "The group function `threshold_fn` must take values in {0, 1}." - ) - return g.astype(bool) - - -class Normalizer: - def __init__(self, xmin: Array, xmax: Array): - self.xmin = xmin - self.xmax = xmax if xmax != xmin else xmin + 1 - - def normalize(self, x: Array) -> Array: - return (x - self.xmin) / (self.xmax - self.xmin) - - def unnormalize(self, y: Array) -> Array: - return self.xmin + (self.xmax - self.xmin) * y - - -class Score: - def __init__(self, score_fn: Callable[[Array, Array], Array]): - self.score_fn = score_fn - - def __call__(self, x: Array, y: Array): - s = self.score_fn(x, y) - if s.ndim > 1: - raise ValueError( - "Evaluations of the score function `score_fn` must be one-dimensional arrays, " - f"but its shape was {s.shape}." - ) - return s - - -class BatchMVPConformalMethod(abc.ABC): - def __init__( - self, - score_fn: Callable[[Array, Array], Array], - group_fns: List[Callable[[Array], Array]], - n_buckets: int = 100, - ): - super().__init__() - self.score_fn = Score(score_fn) - self.group_fns = [Group(g) for g in group_fns] - self.buckets = jnp.linspace(0, 1, n_buckets + 1) - self.n_buckets = n_buckets + 1 - - def threshold_score( - self, - val_data_loader: DataLoader, - test_inputs_loader: InputsLoader, - error: float = 0.05, - tol: float = 1e-4, - n_rounds: int = 1000, - return_max_calib_error: bool = False, - ) -> Union[Array, Tuple[Array, List[Array]]]: - """ - Compute a threshold :math:`f(x)` of the score functions :math:`s(x,y)` for each test input :math:`x`. - Given these threshold, conformal sets can be formulated as :math:`C(x) = \{y: s(x,y) \le f(x)\}`. 
- - Parameters - ---------- - val_data_loader: DataLoader - A data loader of validation data. - test_inputs_loader: InputsLoader - A loader of test input data points. - error: float - A desired coverage error. - tol: float - A tolerance for the maximum calibration error. - n_rounds: int - The maximum number of updates the algorithm will run for. - return_max_calib_error: bool - Whether to return a list of computed maximum calibration error, that is the larger calibration error - over the different groups. - - Returns - ------- - Union[Array, Tuple[Array, List[Array]]] - The compute threshold of the score function for each test input. - """ - quantile = 1 - error - - scores, thresholds, groups = [], [], [] - for inputs, targets in val_data_loader: - scores.append(self.score_fn(inputs, targets)) - thresholds.append(jnp.zeros(inputs.shape[0])) - groups.append( - jnp.concatenate([g(inputs)[:, None] for g in self.group_fns], axis=1) - ) - scores, thresholds, groups = ( - jnp.concatenate(scores), - jnp.concatenate(thresholds), - jnp.concatenate(groups, 0), - ) - - test_thresholds, test_groups = [], [] - for inputs in test_inputs_loader: - test_thresholds.append(jnp.zeros(inputs.shape[0])) - test_groups.append( - jnp.concatenate([g(inputs)[:, None] for g in self.group_fns], axis=-1) - ) - test_thresholds, test_groups = jnp.concatenate( - test_thresholds - ), jnp.concatenate(test_groups, 0) - - normalizer = Normalizer(jnp.min(scores), jnp.max(scores)) - scores = normalizer.normalize(scores) - - n_groups = groups.shape[1] - - def compute_probability_error( - v: Array, g: Array, delta: Union[Array, float] = 0.0 - ): - b = (jnp.abs(thresholds - v) < 0.5 / self.n_buckets) * groups[:, g] - filtered_scores = jnp.where(b, scores, -jnp.ones_like(scores)) - conds = (filtered_scores <= v + delta) * (filtered_scores != -1) - prob_b = jnp.mean(b) - prob = jnp.where(prob_b > 0, jnp.mean(conds) / prob_b, 0.0) - return (quantile - prob) ** 2 - - def calibration_error(v, g): - b = (jnp.abs(thresholds - v) < 0.5 / self.n_buckets) * groups[:, g] - filtered_scores = jnp.where(b, scores, -jnp.ones_like(scores)) - conds = (filtered_scores <= v) * (filtered_scores != -1) - prob_b = jnp.mean(b) - prob = jnp.where(prob_b > 0, jnp.mean(conds) / prob_b, 0.0) - return prob_b * (quantile - prob) ** 2 - - max_calib_errors = None - if return_max_calib_error: - max_calib_errors = [] - - for t in range(n_rounds): - calib_error_vg = vmap( - lambda g: vmap(lambda v: calibration_error(v, g))(self.buckets) - )(jnp.arange(n_groups)) - max_calib_error = calib_error_vg.sum(1).max() - if return_max_calib_error: - max_calib_errors.append(max_calib_error) - if max_calib_error <= tol: - logging.info( - f"The algorithm produced a {tol}-approximately {quantile}-quantile multicalibrated " - f"threshold function after {t} rounds." 
- ) - break - - gt, idx_vt = jnp.unravel_index( - jnp.argmax(calib_error_vg), (n_groups, self.n_buckets) - ) - vt = self.buckets[idx_vt] - - deltat = self.buckets[ - jnp.argmin( - jnp.abs( - vmap(lambda delta: compute_probability_error(vt, gt, delta))( - self.buckets - ) - ) - ) - ] - bt = (jnp.abs(thresholds - vt) < 0.5 / self.n_buckets) * groups[:, gt] - thresholds = thresholds.at[bt].set( - jnp.minimum(thresholds[bt] + deltat, jnp.ones_like(thresholds[bt])) - ) - test_bt = ( - jnp.abs(test_thresholds - vt) < 0.5 / self.n_buckets - ) * test_groups[:, gt] - test_thresholds = test_thresholds.at[test_bt].set( - jnp.minimum( - test_thresholds[test_bt] + deltat, - jnp.ones_like(test_thresholds[test_bt]), - ) - ) - - test_thresholds = normalizer.unnormalize(test_thresholds) - if return_max_calib_error: - return test_thresholds, max_calib_errors - return test_thresholds diff --git a/fortuna/conformal/classification/batch_mvp.py b/fortuna/conformal/classification/batch_mvp.py index eaaf1690..52319b1c 100644 --- a/fortuna/conformal/classification/batch_mvp.py +++ b/fortuna/conformal/classification/batch_mvp.py @@ -1,137 +1,66 @@ -from typing import ( - Callable, - List, - Optional, - Tuple, - Union, -) +from typing import List -from jax import vmap import jax.numpy as jnp import numpy as np -from fortuna.conformal.batch_mvp import BatchMVPConformalMethod from fortuna.conformal.classification.base import ConformalClassifier -from fortuna.data.loader import ( - DataLoader, - InputsLoader, -) +from fortuna.conformal.multivalid.batch_mvp import BatchMVPConformalMethod from fortuna.typing import Array class BatchMVPConformalClassifier(BatchMVPConformalMethod, ConformalClassifier): def __init__( self, - score_fn: Callable[[Array, Array], Array], - group_fns: List[Callable[[Array], Array]], - n_classes: int, - n_buckets: int = 100, ): """ This class implements a classification version of BatchMVP `[Jung et al., 2022] `_, a multivalid conformal prediction method that satisfies coverage guarantees conditioned on group membership and non-conformity threshold. - - Parameters - ---------- - score_fn: Callable[[Array, Array], Array] - A score function mapping a batch of inputs and targets to scalar scores, one for each data point. The - score function represents the degree of non-conformity between inputs and targets. In regression, an - example of score function is :math:`s(x,y)=|y - h(x)|`, where `h` is an arbitrary regression model. - group_fns: List[Callable[[Array], Array]] - A list of group functions, each mapping input data points into boolean arrays which determine whether - an input belongs to a certain group or not. As an example, suppose that we are interested in obtaining - marginal coverages guarantee on both negative and positive scalar inputs. - Then we could define groups functions - :math:`g_1(x) = x < 0` and :math:`g_1(x) = x > 0`. - Note that groups can be overlapping, and do not need to cover the full space of inputs. - n_classes: int - The number of distinct classes to classify among. The underlying assumption is that the classes are - identified with an integer from 0 to :code:`n_classes-1`. - n_buckets: int - The number of buckets that defines the search space between 0 and 1 that determines the updates of the - thresholds for the score function. 
""" - super().__init__(score_fn=score_fn, group_fns=group_fns, n_buckets=n_buckets) - self.n_classes = n_classes + super().__init__() def conformal_set( self, - val_data_loader: DataLoader, - test_inputs_loader: InputsLoader, - error: float = 0.05, - tol: float = 1e-4, - n_rounds: int = 1000, - return_max_calib_error: bool = False, - test_thresholds: Optional[Array] = None, - ) -> Union[List[List[int]], Tuple[List[List[int]], List[Array]]]: + class_scores: Array, + values: Array, + ) -> List[List[int]]: """ - Compute a conformal set for each test input. + Compute a conformal set for each input. Parameters ---------- - val_data_loader: DataLoader - A data loader of validation data. - test_inputs_loader: InputsLoader - A loader of test input data points. - error: float - A desired coverage error. - tol: float - A tolerance for the maximum calibration error. - n_rounds: int - The maximum number of updates the algorithm will run for. - return_max_calib_error: bool - Whether to return a list of computed maximum calibration error, that is the larger calibration error - over the different groups. - test_thresholds: Optional[Array] - The score thresholds computed over the test data set. These should be the output of - `BatchMVP.threshold_score`. If provided, they will not be recomputed internally. + class_scores: Array + A two-dimensional array of scores. The first dimension is over the different inputs. + The second dimension is over all the possible classes. For example, if there are 10 classes, + the first row of `class_scores` show be :math:`[s(x_1, 0), \dots, s(x_1, 9)]`. + values: Array + A one-dimensional array of values over the different inputs. This should be obtained from the `calibrate` + method. Returns ------- - Union[List[List[int]], Tuple[List[List[int]], List[Array]]] - The computed conformal sets for each test input. Optionally, it returns the maximum calibration errors - computed during the algorithm. + List[List[int]] + Conformal sets for each input data point. """ - if test_thresholds is not None and return_max_calib_error: + if class_scores.ndim != 2: raise ValueError( - "If `test_thresholds` is given, `return_max_calib_error` cannot be returned." - ) - if test_thresholds is None: - outs = self.threshold_score( - val_data_loader=val_data_loader, - test_inputs_loader=test_inputs_loader, - error=error, - tol=tol, - n_rounds=n_rounds, - return_max_calib_error=return_max_calib_error, + "`class_scores` must bse a 2-dimensional array. " + "The first dimension is over the different inputs. " + "The second dimension is over all the possible classes." ) - if return_max_calib_error: - test_thresholds, max_calib_errors = outs - else: - test_thresholds = outs - - c = 0 - all_ys = jnp.arange(self.n_classes) - all_bools = [] - for inputs in test_inputs_loader: - batch_thresholds = test_thresholds[c : c + inputs.shape[0]] - all_bools.append( - vmap( - lambda y: self.score_fn(inputs, y) <= batch_thresholds, out_axes=1 - )(all_ys) + if values.ndim != 1: + raise ValueError("`values` must be a 1-dimensional array.") + if class_scores.shape[0] != values.shape[0]: + raise ValueError( + "The first dimension of `class_scores` and `values` must be over the same input data " + "points." 
) - c += inputs.shape[0] - all_bools = jnp.concatenate(all_bools, axis=0) + bools = class_scores <= values[:, None] - sizes = np.sum(all_bools, 1) - sets = np.zeros(c, dtype=object) + sizes = np.sum(bools, 1) + sets = np.zeros(bools.shape[0], dtype=object) for s in np.unique(sizes): idx = jnp.where(sizes == s)[0] - sets[idx] = np.nonzero(all_bools[idx])[1].reshape(len(idx), s).tolist() - sets = sets.tolist() - - if return_max_calib_error: - return sets, max_calib_errors - return sets + sets[idx] = np.nonzero(bools[idx])[1].reshape(len(idx), s).tolist() + return sets.tolist() diff --git a/fortuna/conformal/multivalid/__init__.py b/fortuna/conformal/multivalid/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fortuna/conformal/multivalid/base.py b/fortuna/conformal/multivalid/base.py new file mode 100644 index 00000000..b1207170 --- /dev/null +++ b/fortuna/conformal/multivalid/base.py @@ -0,0 +1,374 @@ +import abc +import logging +from typing import ( + Callable, + Dict, + List, + Optional, + Tuple, + Union, +) + +from jax import vmap +import jax.numpy as jnp + +from fortuna.data.loader import ( + DataLoader, + InputsLoader, +) + +from fortuna.typing import Array + + +class Normalizer: + def __init__(self, x_min: Union[float, Array], x_max: Union[float, Array]): + self.x_min = x_min + self.x_max = x_max if x_max != x_min else x_min + 1 + + def normalize(self, x: Array) -> Array: + return (x - self.x_min) / (self.x_max - self.x_min) + + def unnormalize(self, y: Array) -> Array: + return self.x_min + (self.x_max - self.x_min) * y + + +class Model: + def __init__(self, model_fn: Callable[[Array], Array]): + self.model_fn = model_fn + + def __call__(self, x: Array): + v = self.model_fn(x) + if v.ndim > 1: + raise ValueError( + "Evaluations of the model function `model_fn` must be one-dimensional arrays, " + f"but its shape was {v.shape}." + ) + if jnp.any(v < 0) or jnp.any(v > 1): + raise ValueError("The model function must take values within [0, 1].") + return v + + +class MultivalidMethod: + def __init__(self): + self._patches = [] + self._n_buckets = None + + def calibrate( + self, + scores: Array, + groups: Array, + values: Optional[Array] = None, + test_groups: Optional[Array] = None, + test_values: Optional[Array] = None, + tol: float = 1e-4, + n_buckets: int = None, + n_rounds: int = 1000, + **kwargs, + ) -> Union[Dict, Tuple[Array, Dict]]: + """ + Calibrate the model by finding a list of patches to the model that bring the calibration error below a + certain threshold. + + Parameters + ---------- + scores: Array + A list of scores :math:`s(x, y)` computed on the calibration data. + This should be a one-dimensional array of elements between 0 and 1. + groups: Array + A list of groups :math:`g(x)` computed on the calibration data. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + values: Optional[Array] + The initial model evaluations :math:`f(x)` on the calibration data. If not provided, these are set to 0. + test_groups: Optional[Array] + A list of groups :math:`g(x)` computed on the test data. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + test_values: Optional[Array] + The initial model evaluations :math:`f(x)` on the test data. If not provided, these are set to 0. + tol: float + A tolerance on the reweighted average squared calibration error, i.e.
:math:`\mu(g) K_2(f, g, \mathcal{D})`. + n_buckets: int + The number of buckets used in the algorithm. The smaller the number of buckets, the simpler the model, + the better its generalization abilities. If not provided, We start from 2 buckets, and progressively double + the number of buckets until we find a value for which the calibration error falls below the given + tolerance. Such number of buckets is guaranteed to exist. + n_rounds: int + The maximum number of rounds to run the method for. + Returns + ------- + Union[Dict, Tuple[Array, Dict]] + A status including the number of rounds taken to reach convergence and the calibration errors computed + during the training procedure. if `test_values` and `test_groups` are provided, the list of patches will + be applied to `test_values`, and the calibrated test values will be returned together with the status. + """ + if tol >= 1: + raise ValueError("`tol` must be smaller than 1.") + if n_rounds < 1: + raise ValueError("`n_rounds` must be at least 1.") + + if test_groups is None and test_values is not None: + raise ValueError( + "If `test_values` is provided, `test_groups` must be also provided." + ) + if test_values is not None and values is None: + raise ValueError( + "If `test_values is provided, `values` must also be provided." + ) + if values is not None and test_groups is not None and test_values is None: + raise ValueError( + "If `values` and `test_groups` are provided, `test_values` must also be provided." + ) + + if values is None: + values_init = jnp.zeros(groups.shape[0]) + else: + values_init = jnp.copy(values) + + self._check_range(dict(scores=scores, values=values, test_values=test_values)) + + increase_n_buckets = False + if n_buckets is None: + n_buckets = 2 + increase_n_buckets = True + + n_groups = groups.shape[1] + tol_reached = False + + while True: + logging.info(f"Attempt reaching tolerance with {n_buckets} buckets.") + buckets = self._get_buckets(n_buckets) + values = vmap(lambda v: self._round_to_buckets(v, buckets))(values_init) + + max_calib_errors = [] + old_calib_errors_vg = None + self._patches = [] + + for t in range(n_rounds): + calib_error_vg = vmap( + lambda g: vmap( + lambda v: self._calibration_error( + v, + g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + **kwargs, + ) + )(buckets) + )(jnp.arange(n_groups)) + + max_calib_errors.append(calib_error_vg.sum(1).max()) + if max_calib_errors[-1] <= tol: + tol_reached = True + logging.info( + f"Tolerance satisfied after {t} rounds with {n_buckets} buckets." 
+ ) + break + if old_calib_errors_vg is not None and jnp.allclose( + old_calib_errors_vg, calib_error_vg + ): + break + old_calib_errors_vg = jnp.copy(calib_error_vg) + + gt, vt = self._get_gt_and_vt( + calib_error_vg=calib_error_vg, buckets=buckets, n_groups=n_groups + ) + bt = self._get_b( + groups=groups, values=values, v=vt, g=gt, n_buckets=len(buckets) + ) + patch = self._get_patch( + vt=vt, + gt=gt, + scores=scores, + groups=groups, + values=values, + buckets=buckets, + **kwargs, + ) + values = self._patch(values=values, patch=patch, bt=bt) + + self._patches.append((gt, vt, patch)) + + if tol_reached: + break + if increase_n_buckets: + n_buckets *= 2 + else: + break + + self.n_buckets = n_buckets + status = dict(n_rounds=len(self.patches), max_calib_errors=max_calib_errors) + + if test_groups is not None: + test_values = self.apply_patches(test_groups, test_values) + return test_values, status + return status + + def apply_patches( + self, + groups: Array, + values: Optional[Array] = None, + ) -> Array: + """ + Apply the patches to the model evaluations. + + Parameters + ---------- + groups: Array + A list of groups :math:`g(x)` evaluated over some inputs. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + values: Optional[Array] + The initial model evaluations :math:`f(x)` evaluated over some inputs. If not provided, these are set to 0. + + Returns + ------- + Array + The calibrated values. + """ + if not len(self._patches): + logging.warning("No patches available.") + return values + if values is None: + values = jnp.zeros(groups.shape[0]) + + buckets = self._get_buckets(n_buckets=self.n_buckets) + values = vmap(lambda v: self._round_to_buckets(v, buckets))(values) + + for gt, vt, patch in self._patches: + bt = self._get_b( + groups=groups, values=values, v=vt, g=gt, n_buckets=self.n_buckets + ) + values = self._patch(values=values, bt=bt, patch=patch) + return values + + def calibration_error( + self, + scores: Array, + groups: Array, + values: Array, + n_buckets: int = 10000, + **kwargs, + ) -> Array: + """ + The reweighted average squared calibration error :math:`\mu(g) K_2(f, g, \mathcal{D})`. + + Parameters + ---------- + scores + groups: Array + A list of groups :math:`g(x)` evaluated over some inputs. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + values: Array + The model evaluations, before or after calibration. + n_buckets: int + The number of buckets used in the algorithm. 
+ + Returns + ------- + Array + The computed calibration error for each group + """ + buckets = self._get_buckets(n_buckets) + values = vmap(lambda v: self._round_to_buckets(v, buckets))(values) + + return vmap( + lambda g: vmap( + lambda v: self._calibration_error( + v=v, + g=g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + **kwargs, + ) + )(buckets) + )(jnp.arange(groups.shape[1])).sum(1) + + @property + def patches(self): + return self._patches + + @property + def n_buckets(self): + return self._n_buckets + + @n_buckets.setter + def n_buckets(self, n_buckets): + self._n_buckets = n_buckets + + @abc.abstractmethod + def _calibration_error( + self, + v: float, + g: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + **kwargs, + ): + pass + + @staticmethod + def _init_missing_model_fn(): + return Model(lambda x: jnp.zeros(x.shape[0])) + + @staticmethod + def _get_gt_and_vt( + calib_error_vg: Array, buckets: Array, n_groups: int + ) -> Tuple[Array, Array]: + gt, idx_vt = jnp.unravel_index( + jnp.argmax(calib_error_vg), (n_groups, len(buckets)) + ) + vt = buckets[idx_vt] + return gt, vt + + @staticmethod + def _get_b( + groups: Array, values: Array, v: Array, g: Array, n_buckets: int + ) -> Array: + return (jnp.abs(values - v) < 0.5 / n_buckets) * groups[:, g] + + @abc.abstractmethod + def _get_patch( + self, + vt: Array, + gt: Array, + scores: Array, + groups: Array, + values: Array, + buckets: Array, + **kwargs, + ) -> Array: + pass + + @staticmethod + def _patch(values: Array, patch: Array, bt: Array, _shift: bool = False) -> Array: + return values.at[bt].set( + jnp.minimum( + patch if not _shift else values[bt] + patch, + jnp.ones_like(values[bt]), + ) + ) + + @staticmethod + def _get_buckets(n_buckets: int): + return jnp.linspace(0, 1, n_buckets) + + @staticmethod + def _round_to_buckets(v: Array, buckets: Array): + return buckets[jnp.argmin(jnp.abs(v - buckets))] + + @staticmethod + def _check_range(d): + def _maybe_check(k, v): + if v is not None: + if v.min() < 0 or v.max() > 1: + raise ValueError(f"All elements in `{k}` must be between 0 and 1.") + + for k, v in d.items(): + _maybe_check(k, v) diff --git a/fortuna/conformal/multivalid/batch_mvp.py b/fortuna/conformal/multivalid/batch_mvp.py new file mode 100644 index 00000000..a16281ce --- /dev/null +++ b/fortuna/conformal/multivalid/batch_mvp.py @@ -0,0 +1,210 @@ +from typing import ( + Dict, + Optional, + Tuple, + Union, +) + +from jax import vmap +import jax.numpy as jnp + +from fortuna.conformal.classification.base import ConformalClassifier +from fortuna.conformal.multivalid.base import MultivalidMethod +from fortuna.typing import Array + + +class BatchMVPConformalMethod(MultivalidMethod, ConformalClassifier): + def __init__( + self, + ): + """ + This class implements a classification version of BatchMVP + `[Jung et al., 2022] `_, + a multivalid conformal prediction method that satisfies coverage guarantees conditioned on group membership + and non-conformity threshold. + """ + super().__init__() + self._coverage = None + + def calibrate( + self, + scores: Array, + groups: Array, + values: Optional[Array] = None, + test_groups: Optional[Array] = None, + test_values: Optional[Array] = None, + tol: float = 1e-4, + n_buckets: int = None, + n_rounds: int = 1000, + coverage: float = 0.95, + ) -> Union[Dict, Tuple[Array, Dict]]: + """ + Calibrate the model by finding a list of patches to the model that bring the calibration error below a + certain threshold. 
+ + Parameters + ---------- + scores: Array + A list of scores :math:`s(x, y)` computed on the calibration data. + This should be a one-dimensional array of elements between 0 and 1. + groups: Array + A list of groups :math:`g(x)` computed on the calibration data. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + values: Optional[Array] + The initial model evaluations :math:`f(x)` on the calibration data. If not provided, these are set to 0. + test_groups: Optional[Array] + A list of groups :math:`g(x)` computed on the test data. + This should be a two-dimensional array of bool elements. + The first dimension is over the data points, the second dimension is over the number of groups. + test_values: Optional[Array] + The initial model evaluations :math:`f(x)` on the test data. If not provided, these are set to 0. + tol: float + A tolerance on the reweighted average squared calibration error, i.e. :math:`\mu(g) K_2(f, g, \mathcal{D})`. + n_buckets: int + The number of buckets used in the algorithm. The smaller the number of buckets, the simpler the model, + the better its generalization abilities. If not provided, we start from 2 buckets, and progressively double + the number of buckets until we find a value for which the calibration error falls below the given + tolerance. Such a number of buckets is guaranteed to exist. + n_rounds: int + The maximum number of rounds to run the method for. + coverage: float + The desired level of coverage. This must be a scalar between 0 and 1. + Returns + ------- + Union[Dict, Tuple[Array, Dict]] + A status including the number of rounds taken to reach convergence and the calibration errors computed + during the training procedure. If `test_values` and `test_groups` are provided, the list of patches will + be applied to `test_values`, and the calibrated test values will be returned together with the status.
+ """ + if coverage < 0 or coverage > 1: + raise ValueError("`coverage` must be a float between 0 and 1.") + self._coverage = coverage + return super().calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + test_values=test_values, + tol=tol, + n_buckets=n_buckets, + n_rounds=n_rounds, + coverage=coverage, + ) + + def calibration_error( + self, + scores: Array, + groups: Array, + values: Array, + n_buckets: int = 10000, + **kwargs, + ) -> Array: + return super().calibration_error( + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + coverage=self._coverage, + ) + + def _calibration_error( + self, + v: Array, + g: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + coverage: float = None, + ): + prob_error, prob_b = self._compute_probability_error( + v=v, + g=g, + delta=0.0, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + return_prob_b=True, + coverage=coverage, + ) + return prob_b * prob_error + + def _compute_probability_error( + self, + v: Array, + g: Array, + delta: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + return_prob_b: bool = False, + coverage: float = None, + ): + prob = self._compute_probability( + v=v, + g=g, + delta=delta, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + return_prob_b=return_prob_b, + ) + if return_prob_b: + prob, prob_b = prob + return (coverage - prob) ** 2, prob_b + return (coverage - prob) ** 2 + + def _compute_probability( + self, + v: Array, + g: Array, + delta: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + return_prob_b: bool = False, + ): + b = self._get_b(groups=groups, values=values, v=v, g=g, n_buckets=n_buckets) + conds = (scores <= v + delta) * b + prob_b = jnp.mean(b) + prob = jnp.where(prob_b > 0, jnp.mean(conds) / prob_b, 0.0) + if return_prob_b: + return prob, prob_b + return prob + + def _get_patch( + self, + vt: Array, + gt: Array, + scores: Array, + groups: Array, + values: Array, + buckets: Array, + coverage: float = None, + ) -> Array: + return buckets[ + jnp.argmin( + vmap( + lambda delta: self._compute_probability_error( + v=vt, + g=gt, + delta=delta, + scores=scores, + groups=groups, + values=values, + n_buckets=len(buckets), + coverage=coverage, + ) + )(buckets) + ) + ] + + def _patch( + self, values: Array, patch: Array, bt: Array, _shift: bool = True + ) -> Array: + return super()._patch(values=values, patch=patch, bt=bt, _shift=_shift) diff --git a/fortuna/conformal/multivalid/multicalibrator.py b/fortuna/conformal/multivalid/multicalibrator.py new file mode 100644 index 00000000..882b28c3 --- /dev/null +++ b/fortuna/conformal/multivalid/multicalibrator.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +import jax.numpy as jnp + +from fortuna.conformal.multivalid.base import MultivalidMethod +from fortuna.typing import Array + + +class Multicalibrator(MultivalidMethod): + def __init__(self): + """ + A multicalibration method that provides multivalid coverage guarantees. See Algorithm 15 in `Aaron Roth's notes + `_. 
+ """ + super().__init__() + self._patch_list = [] + + def _calibration_error( + self, + v: Array, + g: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + **kwargs, + ): + expectation_error, prob_b = self._compute_expectation_error( + v=v, + g=g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + return_prob_b=True, + ) + return prob_b * expectation_error + + def _compute_expectation_error( + self, + v: Array, + g: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + return_prob_b: bool = False, + ): + if return_prob_b: + mean, prob_b = self._compute_expectation( + v=v, + g=g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + return_prob_b=return_prob_b, + ) + return (v - mean) ** 2, prob_b + mean = self._compute_expectation( + v=v, + g=g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + return_prob_b=return_prob_b, + ) + return (v - mean) ** 2 + + def _compute_expectation( + self, + v: Array, + g: Array, + scores: Array, + groups: Array, + values: Array, + n_buckets: int, + return_prob_b: bool = False, + ): + b = self._get_b(groups=groups, values=values, v=v, g=g, n_buckets=n_buckets) + filtered_scores = scores * b + prob_b = jnp.mean(b) + mean = jnp.where(prob_b > 0, jnp.mean(filtered_scores) / prob_b, 0.0) + + if return_prob_b: + return mean, prob_b + return mean + + def _get_patch( + self, + vt: Array, + gt: Array, + scores: Array, + groups: Array, + values: Array, + buckets: Array, + **kwargs, + ) -> Array: + patch = self._compute_expectation( + v=vt, + g=gt, + scores=scores, + groups=groups, + values=values, + n_buckets=len(buckets), + ) + return self._round_to_buckets(patch, buckets) diff --git a/fortuna/conformal/regression/batch_mvp.py b/fortuna/conformal/regression/batch_mvp.py index e5de69b9..ea0dd2e1 100644 --- a/fortuna/conformal/regression/batch_mvp.py +++ b/fortuna/conformal/regression/batch_mvp.py @@ -1,150 +1,15 @@ -from typing import ( - Callable, - List, - Optional, - Tuple, - Union, -) - -import jax.numpy as jnp - -from fortuna.conformal.batch_mvp import BatchMVPConformalMethod +from fortuna.conformal.multivalid.batch_mvp import BatchMVPConformalMethod from fortuna.conformal.regression.base import ConformalRegressor -from fortuna.data.loader import ( - DataLoader, - InputsLoader, -) -from fortuna.typing import Array - - -class Bounds: - def __init__(self, bounds_fn: Callable[[Array, Array], Tuple[Array, Array]]): - self.bounds_fn = bounds_fn - - def __call__(self, x: Array, t: Array) -> Tuple[Array, Array]: - bl, br = self.bounds_fn(x, t) - if bl.ndim > 1: - raise ValueError( - "Evaluations of the bounds function must e a tuple of two one-dimensional arrays. " - f"However, the first array has shape {bl.shape}." - ) - if br.ndim > 1: - raise ValueError( - "Evaluations of the bounds function must be a tuple of two one-dimensional arrays. " - f"However, the second array has shape {bl.shape}." - ) - if len(bl) != len(br): - raise ValueError( - "Evaluations of the bounds function must be a tuple of two one-dimensional arrays " - f"with same length. However, lengths {len(bl)} and {len(br)} were found, " - f"respectively." 
- ) - return bl, br class BatchMVPConformalRegressor(BatchMVPConformalMethod, ConformalRegressor): def __init__( self, - score_fn: Callable[[Array, Array], Array], - group_fns: List[Callable[[Array], Array]], - bounds_fn: Callable[[Array, Array], Tuple[Array, Array]], - n_buckets: int = 100, ): """ - This class implements a regression version of BatchMVP + This class implements a regression version of BatchMVP `[Jung et al., 2022] `_, a multivalid conformal prediction method that satisfies coverage guarantees conditioned on group membership and non-conformity threshold. - - Parameters - ---------- - score_fn: Callable[[Array, Array], Array] - A score function mapping a batch of inputs and targets to scalar scores, one for each data point. The - score function represents the degree of non-conformity between inputs and targets. In regression, an - example of score function is :math:`s(x,y)=|y - h(x)|`, where `h` is an arbitrary regression model. - group_fns: List[Callable[[Array], Array]] - A list of group functions, each mapping input data points into boolean arrays which determine whether - an input belongs to a certain group or not. As an example, suppose that we are interested in obtaining - marginal coverages guarantee on both negative and positive scalar inputs. - Then we could define groups functions - :math:`g_1(x) = x < 0` and :math:`g_1(x) = x > 0`. - Note that groups can be overlapping, and do not need to cover the full space of inputs. - bounds_fn: Callable[[Array, Array], Array] - A function taking a batch of input data points and respective score thresholds, - and returning a tuple of arrays, respectively lower and upper bounds for each input. - n_buckets: int - The number of buckets that defines the search space between 0 and 1 that determines the updates of the - thresholds for the score function. - """ - super().__init__(score_fn=score_fn, group_fns=group_fns, n_buckets=n_buckets) - self.bounds_fn = Bounds(bounds_fn=bounds_fn) - - def conformal_interval( - self, - val_data_loader: DataLoader, - test_inputs_loader: InputsLoader, - error: float = 0.05, - tol: float = 1e-4, - n_rounds: int = 1000, - return_max_calib_error: bool = False, - test_thresholds: Optional[Array] = None, - ) -> Union[Array, Tuple[Array, List[Array]]]: """ - Compute a conformal interval for each test input. - - Parameters - ---------- - val_data_loader: DataLoader - A data loader of validation data. - test_inputs_loader: InputsLoader - A loader of test input data points. - error: float - A desired coverage error. - tol: float - A tolerance for the maximum calibration error. - n_rounds: int - The maximum number of updates the algorithm will run for. - return_max_calib_error: bool - Whether to return a list of computed maximum calibration error, that is the larger calibration error - over the different groups. - test_thresholds: Optional[Array] - The score thresholds computed over the test data set. These should be the output of - `BatchMVP.threshold_score`. If provided, they will not be recomputed internally. - - Returns - ------- - Union[Array, Tuple[Array, List[Array]]] - The computed conformal intervals for each test input. - Optionally, it returns the maximum calibration errors computed during the algorithm. - """ - if test_thresholds is not None and return_max_calib_error: - raise ValueError( - "If `test_thresholds` is given, `return_max_calib_error` cannot be returned."
- ) - if test_thresholds is None: - outs = self.threshold_score( - val_data_loader=val_data_loader, - test_inputs_loader=test_inputs_loader, - error=error, - tol=tol, - n_rounds=n_rounds, - return_max_calib_error=return_max_calib_error, - ) - if return_max_calib_error: - test_thresholds, max_calib_errors = outs - else: - test_thresholds = outs - - c = 0 - intervals = [] - for inputs in test_inputs_loader: - left, right = self.bounds_fn( - inputs, test_thresholds[c : c + inputs.shape[0]] - ) - intervals.append(jnp.stack((left, right), axis=1)) - c += inputs.shape[0] - intervals = jnp.concatenate(intervals, axis=0) - - if return_max_calib_error: - return intervals, max_calib_errors - return intervals + super().__init__() diff --git a/fortuna/metric/classification.py b/fortuna/metric/classification.py index 5fffaa22..3c6c8236 100755 --- a/fortuna/metric/classification.py +++ b/fortuna/metric/classification.py @@ -202,7 +202,7 @@ def brier_score(probs: Array, targets: Union[TargetsLoader, Array]) -> jnp.ndarr Parameters ---------- probs: Array - A two-dimensional array of class probabilities for each data point. + A one- or two-dimensional array of class probabilities for each data point. targets: Array A one-dimensional array of target variables. @@ -211,7 +211,11 @@ def brier_score(probs: Array, targets: Union[TargetsLoader, Array]) -> jnp.ndarr jnp.ndarray The Brier score. """ + if probs.ndim > 2: + raise ValueError("`probs` can be at most 2 dimensional.") if type(targets) == TargetsLoader: targets = targets.to_array_targets() - targets = jax.nn.one_hot(targets, probs.shape[-1]) - return jnp.mean(jnp.sum((probs - targets) ** 2, axis=1)) + if probs.ndim > 1: + targets = jax.nn.one_hot(targets, probs.shape[-1]) + return jnp.mean(jnp.sum((probs - targets) ** 2, axis=-1)) + return jnp.mean((probs - targets) ** 2) diff --git a/pyproject.toml b/pyproject.toml index 5adc22a2..1687732c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "aws-fortuna" -version = "0.1.21" +version = "0.1.22" description = "A Library for Uncertainty Quantification." 
authors = ["Gianluca Detommaso ", "Alberto Gasparin "] license = "Apache-2.0" diff --git a/tests/fortuna/test_conformal_methods.py b/tests/fortuna/test_conformal_methods.py index 1b3df237..b75b7ff6 100755 --- a/tests/fortuna/test_conformal_methods.py +++ b/tests/fortuna/test_conformal_methods.py @@ -1,5 +1,6 @@ import unittest +from jax import random import jax.numpy as jnp import numpy as np @@ -17,6 +18,7 @@ QuantileConformalRegressor, SimplePredictionConformalClassifier, ) +from fortuna.conformal.multivalid.multicalibrator import Multicalibrator from fortuna.data.loader import ( DataLoader, InputsLoader, @@ -317,41 +319,178 @@ def test_adaptive_conformal_classification(self): ) def test_batchmvp_regressor(self): - batchmvp = BatchMVPConformalRegressor( - score_fn=lambda x, y: jnp.abs(y - x) / 15, - group_fns=[lambda x: x > 0.1, lambda x: x < 0.2, lambda x: x > 0.3], - bounds_fn=lambda x, t: (x - t, x + t), + size = 10 + test_size = 20 + scores = random.uniform(random.PRNGKey(0), shape=(size,)) + groups = random.choice(random.PRNGKey(0), 2, shape=(size, 3)) + values = jnp.zeros(size) + test_scores = random.uniform(random.PRNGKey(0), shape=(test_size,)) + test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) + test_values = jnp.zeros(test_size) + batchmvp = BatchMVPConformalRegressor() + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 ) - val_data_loader = DataLoader.from_array_data( - (self._rng.normal(size=(50,)), self._rng.normal(size=(50,))), - batch_size=32, + status = batchmvp.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 ) - test_inputs_loader = InputsLoader.from_array_inputs( - self._rng.normal(size=(150,)), - batch_size=32, + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) + test_values = batchmvp.apply_patches(test_groups) + test_values = batchmvp.apply_patches(test_groups, test_values) + error = batchmvp.calibration_error( + scores=test_scores, groups=test_groups, values=test_values ) - - intervals = batchmvp.conformal_interval(val_data_loader, test_inputs_loader) - assert intervals.shape == (150, 2) def test_batchmvp_classifier(self): - batchmvp = BatchMVPConformalClassifier( - score_fn=lambda x, y: 1 - jnp.mean(x, 0)[y], - group_fns=[ - lambda x: x[:, 0] > 0.1, - lambda x: x[:, 0] < 0.2, - lambda x: x[:, 0] > 0.3, - ], - n_classes=2, - ) - val_data_loader = DataLoader.from_array_data( - (self._rng.normal(size=(50, 1)), self._rng.choice(2, 50)), - batch_size=32, - ) - test_inputs_loader = InputsLoader.from_array_inputs( - self._rng.normal(size=(150, 1)), - batch_size=32, - ) - - sets = batchmvp.conformal_set(val_data_loader, test_inputs_loader) - assert len(sets) == 150 + size = 10 + test_size = 20 + scores = 
random.uniform(random.PRNGKey(0), shape=(size,)) + groups = random.choice(random.PRNGKey(0), 2, shape=(size, 3)) + values = jnp.zeros(size) + test_scores = random.uniform(random.PRNGKey(0), shape=(test_size,)) + test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) + batchmvp = BatchMVPConformalClassifier() + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 + ) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) + test_values = batchmvp.apply_patches(test_groups) + test_values = batchmvp.apply_patches(test_groups, test_values) + error = batchmvp.calibration_error( + scores=test_scores, groups=test_groups, values=test_values + ) + + sets = batchmvp.conformal_set( + class_scores=jnp.stack((test_scores, test_scores), axis=1), + values=test_values, + ) + assert len(sets) == test_size + + def test_multicalibrator(self): + size = 10 + test_size = 20 + scores = random.uniform(random.PRNGKey(0), shape=(size,)) + groups = random.choice(random.PRNGKey(0), 2, shape=(size, 3)) + values = jnp.zeros(size) + test_scores = random.uniform(random.PRNGKey(0), shape=(test_size,)) + test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) + mc = Multicalibrator() + status = mc.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) + status = mc.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 + ) + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + with self.assertRaises(ValueError): + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + status = mc.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) + test_values = mc.apply_patches(test_groups) + test_values = mc.apply_patches(test_groups, test_values) + error = mc.calibration_error( + scores=test_scores, groups=test_groups, values=test_values + )
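The new `conformal_set` API is only briefly exercised above; the following sketch shows the intended end-to-end flow on synthetic scores (all data, sizes, and names are illustrative, and a single all-inclusive group is used for simplicity):

import jax.numpy as jnp
from jax import random

from fortuna.conformal import BatchMVPConformalClassifier

key_scores, key_labels = random.split(random.PRNGKey(0))
n_inputs, n_classes = 100, 3
# Per-class non-conformity scores s(x_i, k), assumed to be already rescaled to [0, 1].
class_scores = random.uniform(key_scores, shape=(n_inputs, n_classes))
labels = random.choice(key_labels, n_classes, shape=(n_inputs,))
scores = class_scores[jnp.arange(n_inputs), labels]  # s(x_i, y_i) on the calibration data
groups = jnp.ones((n_inputs, 1), dtype=bool)  # one group containing every input

batchmvp = BatchMVPConformalClassifier()
# Calibrate per-input thresholds and obtain them directly for the (here identical) test groups.
values, status = batchmvp.calibrate(
    scores=scores, groups=groups, test_groups=groups, n_buckets=10, coverage=0.9
)
# A class k enters the set of input x_i whenever s(x_i, k) is below the input's threshold.
sets = batchmvp.conformal_set(class_scores=class_scores, values=values)
print(sets[:5])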