
Commit

Add multicalibrator (#109)
* edit installation instructions in readme

* bump up version

* make small change in readme because of publish to pypi error

* add batch multivalid mean predictor

* go forward

* continue

* add multicalibrator and refactor batchmvp

* black code

* bump up version
gianlucadetommaso authored Jul 30, 2023
1 parent 2fa48c7 commit 31db6f7
Showing 15 changed files with 1,110 additions and 513 deletions.
150 changes: 150 additions & 0 deletions benchmarks/two_moons_multicalibrate.py
@@ -0,0 +1,150 @@
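# Benchmark: train a MAP classifier on the two-moons dataset, then use the new
# Multicalibrator to calibrate the predicted probability of class 1.
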
import flax.linen as nn
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from sklearn.datasets import make_moons

from fortuna.conformal import Multicalibrator
from fortuna.data import (
    DataLoader,
    InputsLoader,
)
from fortuna.metric.classification import accuracy
from fortuna.model.mlp import MLP
from fortuna.prob_model import (
    CalibConfig,
    CalibMonitor,
    FitConfig,
    FitMonitor,
    FitOptimizer,
    MAPPosteriorApproximator,
    ProbClassifier,
)

train_data = make_moons(n_samples=5000, noise=0.07, random_state=0)
val_data = make_moons(n_samples=1000, noise=0.07, random_state=1)
test_data = make_moons(n_samples=1000, noise=0.07, random_state=2)

train_data_loader = DataLoader.from_array_data(
    train_data, batch_size=128, shuffle=True, prefetch=True
)
val_data_loader = DataLoader.from_array_data(val_data, batch_size=128, prefetch=True)
test_data_loader = DataLoader.from_array_data(test_data, batch_size=128, prefetch=True)

output_dim = 2
prob_model = ProbClassifier(
    model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
    posterior_approximator=MAPPosteriorApproximator(),
)

status = prob_model.train(
    train_data_loader=train_data_loader,
    val_data_loader=val_data_loader,
    calib_data_loader=val_data_loader,
    fit_config=FitConfig(
        monitor=FitMonitor(metrics=(accuracy,), early_stopping_patience=10),
        optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=10),
    ),
    calib_config=CalibConfig(monitor=CalibMonitor(early_stopping_patience=2)),
)

test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
    inputs_loader=test_inputs_loader, means=test_means
)

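# Visualize the predictive entropy over a grid, together with the predicted
# labels on the test inputs.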
fig = plt.figure(figsize=(6, 3))
size = 150
xx = np.linspace(-4, 4, size)
yy = np.linspace(-4, 4, size)
grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy])
grid_loader = InputsLoader.from_array_inputs(grid)
grid_entropies = prob_model.predictive.entropy(grid_loader).reshape(size, size)
grid = grid.reshape(size, size, 2)
plt.title("Predictions and entropy", fontsize=12)
im = plt.pcolor(grid[:, :, 0], grid[:, :, 1], grid_entropies)
plt.scatter(
    test_data[0][:, 0],
    test_data[0][:, 1],
    s=1,
    c=["C0" if i == 1 else "C1" for i in test_modes],
)
plt.colorbar()
plt.show()

val_inputs_loader = val_data_loader.to_inputs_loader()
test_inputs_loader = test_data_loader.to_inputs_loader()
val_targets = val_data_loader.to_array_targets()
test_targets = test_data_loader.to_array_targets()

val_means = prob_model.predictive.mean(val_inputs_loader)
test_means = prob_model.predictive.mean(test_inputs_loader)

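# Multicalibrate the probability of class 1: the binary targets act as scores,
# inputs are grouped by their predicted class, and the class-1 probabilities
# are the values to calibrate.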
mc = Multicalibrator()
scores = val_targets
test_scores = test_targets
groups = jnp.stack((val_means.argmax(1) == 0, val_means.argmax(1) == 1), axis=1)
test_groups = jnp.stack((test_means.argmax(1) == 0, test_means.argmax(1) == 1), axis=1)
values = val_means[:, 1]
test_values = test_means[:, 1]
calib_test_values, status = mc.calibrate(
    scores=scores,
    groups=groups,
    values=values,
    test_groups=test_groups,
    test_values=test_values,
    n_buckets=1000,
)

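# Compare the distribution of class-1 probabilities before and after
# calibration, overall and within each predicted-class group.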
plt.figure(figsize=(10, 3))
plt.suptitle("Multivalid calibration of probability that Y=1")
plt.subplot(1, 3, 1)
plt.title("all test inputs")
plt.hist([test_values, calib_test_values])
plt.legend(["before calibration", "after calibration"])
plt.xlabel("prob")
plt.subplot(1, 3, 2)
plt.title("inputs for which we predict 0")
plt.hist([test_values[test_groups[:, 0]], calib_test_values[test_groups[:, 0]]])
plt.xlabel("prob")
plt.subplot(1, 3, 3)
plt.title("inputs for which we predict 1")
plt.hist([test_values[test_groups[:, 1]], calib_test_values[test_groups[:, 1]]])
plt.xlabel("prob")
plt.tight_layout()
plt.show()

plt.title("Max calibration error decay during calibration")
plt.semilogy(status["max_calib_errors"])
plt.show()

print(
    "Per-group reweighted avg. squared calib. error before calibration: ",
    mc.calibration_error(
        scores=test_scores, groups=test_groups, values=test_values
    ),
)
print(
    "Per-group reweighted avg. squared calib. error after calibration: ",
    mc.calibration_error(
        scores=test_scores, groups=test_groups, values=calib_test_values
    ),
)

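# The mismatch below is max((1 - y) * p, y * (1 - p)), i.e. the probability
# mass assigned to the wrong label at each test point.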
print(
    "Mismatch between labels and probs before calibration: ",
    jnp.mean(
        jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))
    ),
)
print(
    "Mismatch between labels and probs after calibration: ",
    jnp.mean(
        jnp.maximum(
            (1 - test_targets) * calib_test_values,
            test_targets * (1 - calib_test_values),
        )
    ),
)
13 changes: 11 additions & 2 deletions docs/source/methods.rst
@@ -62,9 +62,14 @@ For classification:
  sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time.

- **BatchMVP** `[Jung C. et al., 2022] <https://arxiv.org/pdf/2209.15145.pdf>`_
  A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and
  non-conformity thresholds.

- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] <https://arxiv.org/abs/1711.08513>`_, `[Roth A., Algorithm 15] <https://www.cis.upenn.edu/~aaroth/uncertainty-notes.pdf>`_
  Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point.
  For example, in binary classification, it can return calibrated probabilities of predictions.
  This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds.

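A minimal usage sketch, mirroring `benchmarks/two_moons_multicalibrate.py` above; here `scores`, `groups`, `values`, `test_groups` and `test_values` are assumed to be arrays built as in that script:

from fortuna.conformal import Multicalibrator

mc = Multicalibrator()
calib_test_values, status = mc.calibrate(
    scores=scores,
    groups=groups,
    values=values,
    test_groups=test_groups,
    test_values=test_values,
    n_buckets=1000,
)
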
For regression:

- **Conformalized quantile regression** `[Romano et al., 2019] <https://proceedings.neurips.cc/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf>`_
@@ -79,12 +84,16 @@ For regression:
satisfying minimal coverage properties.

- **BatchMVP** `[Jung C. et al., 2022] <https://arxiv.org/pdf/2209.15145.pdf>`_
  A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and
  non-conformity thresholds.

- **EnbPI** `[Xu et al., 2021] <http://proceedings.mlr.press/v139/xu21h/xu21h.pdf>`_
  A conformal prediction method for time series regression based on data bootstrapping.

- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] <https://arxiv.org/abs/1711.08513>`_, `[Roth A., Algorithm 15] <https://www.cis.upenn.edu/~aaroth/uncertainty-notes.pdf>`_
  Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point.
  This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds.

- **Adaptive conformal inference** `[Gibbs et al., 2021] <https://proceedings.neurips.cc/paper/2021/hash/0d441de75945e5acbc865406fc9a2559-Abstract.html>`_
  A method for conformal prediction that aims at correcting the coverage of conformal prediction methods in a
  sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time.
4 changes: 4 additions & 0 deletions docs/source/references/conformal.rst
@@ -16,6 +16,8 @@ and :ref:`regression <conformal_regression>`.

.. automodule:: fortuna.conformal.classification.batch_mvp

.. automodule:: fortuna.conformal.classification.multicalibrator

.. _conformal_regression:

.. automodule:: fortuna.conformal.regression.quantile
@@ -33,3 +35,5 @@ and :ref:`regression <conformal_regression>`.
.. automodule:: fortuna.conformal.regression.adaptive_conformal_regressor

.. automodule:: fortuna.conformal.regression.batch_mvp

.. automodule:: fortuna.conformal.regression.multicalibrator
63 changes: 38 additions & 25 deletions examples/multivalid_coverage.pct.py
@@ -206,48 +206,52 @@ def plot_intervals(xx, means, intervals, test_data, method):
# %% [markdown]
# We finally introduce Batch MVP [[Jung C. et al., 2022]](https://arxiv.org/pdf/2209.15145.pdf) and show that it improves group-conditional coverage. For its usage, we require:
#
# - non-conformity scores evaluated on calibration data. These can be evaluations of any score function measuring the degree of non-conformity between inputs $x$ and targets $y$. The less $x$ and $y$ conform with each other, the larger the score should be. A simple example of a score function in regression is $s(x,y)=|y - h(x)|$, where $h$ is an arbitrary model. For the purpose of this example, we use the same score function as in CQR, that is $s(x,y)=\max\left(q_{\frac{\alpha}{2}} - y, y - q_{1 - \frac{\alpha}{2}}\right)$, where $\alpha$ is the desired coverage error, i.e. $\alpha=0.05$, and $q_\alpha$ is a corresponding quantile.
#
# - group evaluations on calibration and test data. These construct sub-domains of interest of the input domain. As we defined above, here we use $g_1(x) = \mathbb{1}[x < 0]$ and $g_2(x) = \mathbb{1}[x \ge 0]$.
#
# That's it! Once these are defined, we are ready to run `BatchMVP`.

# %%
from fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor
import jax.numpy as jnp


qleft, qright = prob_model.predictive.quantile(
    [0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader
)
scores = jnp.maximum(qleft - calib_targets, calib_targets - qright).squeeze(1)
# Rescale the scores to [0, 1]; the calibrated thresholds are mapped back below.
min_score, max_score = scores.min(), scores.max()
scores = (scores - min_score) / (max_score - min_score)
groups = jnp.stack([g(calib_data[0]) for g in group_fns], axis=1)
test_groups = jnp.stack([g(test_data[0]) for g in group_fns], axis=1)

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
    scores=scores, groups=groups, test_groups=test_groups
)
test_thresholds = min_score + (max_score - min_score) * test_thresholds

# %% [markdown]
# At each iteration, `BatchMVP` computes the maximum calibration error over the different groups. We report its decay in the following picture.

# %%
plt.figure(figsize=(6, 3))
plt.plot(status["max_calib_errors"], label="maximum calibration error decay")
plt.xlabel("rounds")
plt.legend()
plt.show()

# %% [markdown]
# Given the test thresholds, we can find the lower and upper bounds of the conformal intervals by inverting the score function $s(x, y)$ with respect to $y$. This gives $b(x, \tau) = [q_{\frac{\alpha}{2}}(x) - \tau, q_{1 - \frac{\alpha}{2}}(x) + \tau]$, where $\tau$ denotes the thresholds.

# %%
test_qleft, test_qright = prob_model.predictive.quantile(
    [0.05 / 2, 1 - 0.05 / 2], test_inputs_loader
)
test_qleft, test_qright = test_qleft.squeeze(1), test_qright.squeeze(1)
test_batchmvp_intervals = jnp.stack(
    (test_qleft - test_thresholds, test_qright + test_thresholds), axis=1
)

# %% [markdown]
# We now compute coverage metrics. As expected, `BatchMVP` not only provides a good marginal coverage overall, but also improves coverage on both negative and positive inputs.
Expand All @@ -268,5 +272,14 @@ def bounds_fn(x, t):
# Once again, we visualize predictions and estimated intervals.

# %%
xx_qleft, xx_qright = prob_model.predictive.quantile(
    [0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx)
)
xx_qleft, xx_qright = xx_qleft.squeeze(1), xx_qright.squeeze(1)
xx_groups = jnp.stack([g(xx) for g in group_fns], axis=1)
xx_thresholds = batchmvp.apply_patches(groups=xx_groups)
xx_thresholds = min_score + (max_score - min_score) * xx_thresholds
xx_batchmvp_intervals = jnp.stack(
    (xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1
)
plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP")
1 change: 1 addition & 0 deletions fortuna/conformal/__init__.py
@@ -9,6 +9,7 @@
from fortuna.conformal.classification.simple_prediction import (
    SimplePredictionConformalClassifier,
)
from fortuna.conformal.multivalid.multicalibrator import Multicalibrator
from fortuna.conformal.regression.adaptive_conformal_regressor import (
    AdaptiveConformalRegressor,
)
