Add multicalibrator #109

Merged · 36 commits · Jul 30, 2023

Commits
52e96ea
edit installation instructions in readme
gianlucadetommaso May 15, 2023
5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
6cb6581
bump up version
gianlucadetommaso May 15, 2023
1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 25, 2023
580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 27, 2023
048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 2, 2023
ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 21, 2023
1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso Jul 18, 2023
2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6e030f2
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
3a194d2
add batch multivalid mean predictor
gianlucadetommaso Jul 25, 2023
9bd6f67
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
c5bc94f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
ff07864
Merge branch 'main' of https://github.com/awslabs/fortuna into mvmp
gianlucadetommaso Jul 25, 2023
5c0d0e5
go forward
gianlucadetommaso Jul 26, 2023
d3ab46b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
1d0fa81
Merge branch 'main' into mvmp
gianlucadetommaso Jul 26, 2023
95de324
continue
gianlucadetommaso Jul 26, 2023
0e2aca5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
26eb0a9
Merge branch 'main' into mvmp
gianlucadetommaso Jul 26, 2023
89c5cf9
add multicalibrator and refactor batchmvp
gianlucadetommaso Jul 28, 2023
1963ac3
black code
gianlucadetommaso Jul 30, 2023
f2e5d2a
bump up version
gianlucadetommaso Jul 30, 2023
150 changes: 150 additions & 0 deletions benchmarks/two_moons_multicalibrate.py
@@ -0,0 +1,150 @@
import flax.linen as nn
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from sklearn.datasets import make_moons

from fortuna.conformal import Multicalibrator
from fortuna.data import (
DataLoader,
InputsLoader,
)
from fortuna.metric.classification import accuracy
from fortuna.model.mlp import MLP
from fortuna.prob_model import (
CalibConfig,
CalibMonitor,
FitConfig,
FitMonitor,
FitOptimizer,
MAPPosteriorApproximator,
ProbClassifier,
)

train_data = make_moons(n_samples=5000, noise=0.07, random_state=0)
val_data = make_moons(n_samples=1000, noise=0.07, random_state=1)
test_data = make_moons(n_samples=1000, noise=0.07, random_state=2)

train_data_loader = DataLoader.from_array_data(
train_data, batch_size=128, shuffle=True, prefetch=True
)
val_data_loader = DataLoader.from_array_data(val_data, batch_size=128, prefetch=True)
test_data_loader = DataLoader.from_array_data(test_data, batch_size=128, prefetch=True)

output_dim = 2
prob_model = ProbClassifier(
model=MLP(output_dim=output_dim, activations=(nn.tanh, nn.tanh)),
posterior_approximator=MAPPosteriorApproximator(),
)

status = prob_model.train(
train_data_loader=train_data_loader,
val_data_loader=val_data_loader,
calib_data_loader=val_data_loader,
fit_config=FitConfig(
monitor=FitMonitor(metrics=(accuracy,), early_stopping_patience=10),
optimizer=FitOptimizer(method=optax.adam(1e-4), n_epochs=10),
),
calib_config=CalibConfig(monitor=CalibMonitor(early_stopping_patience=2)),
)

test_inputs_loader = test_data_loader.to_inputs_loader()
test_means = prob_model.predictive.mean(inputs_loader=test_inputs_loader)
test_modes = prob_model.predictive.mode(
inputs_loader=test_inputs_loader, means=test_means
)

fig = plt.figure(figsize=(6, 3))
size = 150
xx = np.linspace(-4, 4, size)
yy = np.linspace(-4, 4, size)
grid = np.array([[_xx, _yy] for _xx in xx for _yy in yy])
grid_loader = InputsLoader.from_array_inputs(grid)
grid_entropies = prob_model.predictive.entropy(grid_loader).reshape(size, size)
grid = grid.reshape(size, size, 2)
plt.title("Predictions and entropy", fontsize=12)
im = plt.pcolor(grid[:, :, 0], grid[:, :, 1], grid_entropies)
plt.scatter(
test_data[0][:, 0],
test_data[0][:, 1],
s=1,
c=["C0" if i == 1 else "C1" for i in test_modes],
)
plt.colorbar()
plt.show()

val_inputs_loader = val_data_loader.to_inputs_loader()
test_inputs_loader = test_data_loader.to_inputs_loader()
val_targets = val_data_loader.to_array_targets()
test_targets = test_data_loader.to_array_targets()

val_means = prob_model.predictive.mean(val_inputs_loader)
test_means = prob_model.predictive.mean(test_inputs_loader)

mc = Multicalibrator()
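# The binary labels act as scores, boolean masks over the predicted class
# define the groups, and the class-1 probabilities are the values to calibrate.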
scores = val_targets
test_scores = test_targets
groups = jnp.stack((val_means.argmax(1) == 0, val_means.argmax(1) == 1), axis=1)
test_groups = jnp.stack((test_means.argmax(1) == 0, test_means.argmax(1) == 1), axis=1)
values = val_means[:, 1]
test_values = test_means[:, 1]
calib_test_values, status = mc.calibrate(
scores=scores,
groups=groups,
values=values,
test_groups=test_groups,
test_values=test_values,
n_buckets=1000,
)
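# calib_test_values holds the calibrated class-1 probabilities on the test
# inputs; status collects diagnostics such as status["max_calib_errors"],
# whose decay is plotted further below.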

plt.figure(figsize=(10, 3))
plt.suptitle("Multivalid calibration of probability that Y=1")
plt.subplot(1, 3, 1)
plt.title("all test inputs")
plt.hist([test_values, calib_test_values])
plt.legend(["before calibration", "after calibration"])
plt.xlabel("prob")
plt.subplot(1, 3, 2)
plt.title("inputs for which we predict 0")
plt.hist([test_values[test_groups[:, 0]], calib_test_values[test_groups[:, 0]]])
plt.xlabel("prob")
plt.subplot(1, 3, 3)
plt.title("inputs for which we predict 1")
plt.hist([test_values[test_groups[:, 1]], calib_test_values[test_groups[:, 1]]])
plt.xlabel("prob")
plt.tight_layout()
plt.show()

plt.title("Max calibration error decay during calibration")
plt.semilogy(status["max_calib_errors"])
plt.show()

print(
"Per-group reweighted avg. squared calib. error before calibration: ",
mc.calibration_error(
scores=test_scores, groups=test_groups, values=test_values
),
)
print(
"Per-group reweighted avg. squared calib. error after calibration: ",
mc.calibration_error(
scores=test_scores, groups=test_groups, values=calib_test_values
),
)
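# The "mismatch" below is, for each test point, the probability mass assigned
# to the wrong label: p when the label is 0, and 1 - p when the label is 1.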

print(
"Mismatch between labels and probs before calibration: ",
jnp.mean(
jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))
),
)
print(
"Mismatch between labels and probs after calibration: ",
jnp.mean(
jnp.maximum(
(1 - test_targets) * calib_test_values,
test_targets * (1 - calib_test_values),
)
),
)
13 changes: 11 additions & 2 deletions docs/source/methods.rst
@@ -62,9 +62,14 @@ For classification:
sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time.

- **BatchMVP** `[Jung C. et al., 2022] <https://arxiv.org/pdf/2209.15145.pdf>`_
A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and
non-conformity thresholds.

- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] <https://arxiv.org/abs/1711.08513>`_, `[Roth A., Algorithm 15] <https://www.cis.upenn.edu/~aaroth/uncertainty-notes.pdf>`_
Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point.
For example, in binary classification, it can return calibrated probabilities of predictions.
This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds.
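
A minimal usage sketch (mirroring ``benchmarks/two_moons_multicalibrate.py`` added in this pull request; the toy arrays below are illustrative stand-ins for real model outputs and labels):

.. code-block:: python

   import jax.numpy as jnp

   from fortuna.conformal import Multicalibrator

   # toy stand-ins: validation/test class probabilities and binary labels
   val_probs = jnp.array([[0.7, 0.3], [0.2, 0.8], [0.6, 0.4], [0.1, 0.9]])
   val_targets = jnp.array([0, 1, 1, 1])
   test_probs = jnp.array([[0.55, 0.45], [0.3, 0.7]])

   # one boolean group per predicted class
   groups = jnp.stack((val_probs.argmax(1) == 0, val_probs.argmax(1) == 1), axis=1)
   test_groups = jnp.stack(
       (test_probs.argmax(1) == 0, test_probs.argmax(1) == 1), axis=1
   )

   mc = Multicalibrator()
   calib_test_probs, status = mc.calibrate(
       scores=val_targets,  # binary labels play the role of scores
       groups=groups,
       values=val_probs[:, 1],  # uncalibrated probabilities of class 1
       test_groups=test_groups,
       test_values=test_probs[:, 1],
       n_buckets=10,
   )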

For regression:

- **Conformalized quantile regression** `[Romano et al., 2019] <https://proceedings.neurips.cc/paper/2019/file/5103c3584b063c431bd1268e9b5e76fb-Paper.pdf>`_
@@ -79,12 +84,16 @@ For regression:
satisfying minimal coverage properties.

- **BatchMVP** `[Jung C. et al., 2022] <https://arxiv.org/pdf/2209.15145.pdf>`_
A conformal prediction algorithm that satisfies coverage guarantees conditioned on group membership and
non-conformity thresholds.

- **EnbPI** `[Xu et al., 2021] <http://proceedings.mlr.press/v139/xu21h/xu21h.pdf>`_
A conformal prediction method for time series regression based on data bootstrapping.

- **Multicalibrate** `[Hébert-Johnson Ú. et al., 2017] <https://arxiv.org/abs/1711.08513>`_, `[Roth A., Algorithm 15] <https://www.cis.upenn.edu/~aaroth/uncertainty-notes.pdf>`_
Unlike standard conformal prediction methods, this algorithm returns scalar calibrated score values for each data point.
This method satisfies coverage guarantees conditioned on group membership and non-conformity thresholds.

- **Adaptive conformal inference** `[Gibbs et al., 2021] <https://proceedings.neurips.cc/paper/2021/hash/0d441de75945e5acbc865406fc9a2559-Abstract.html>`_
A method for conformal prediction that aims at correcting the coverage of conformal prediction methods in a
sequential prediction framework (e.g. time series forecasting) when the distribution of the data shifts over time.
4 changes: 4 additions & 0 deletions docs/source/references/conformal.rst
@@ -16,6 +16,8 @@ and :ref:`regression <conformal_regression>`.

.. automodule:: fortuna.conformal.classification.batch_mvp

.. automodule:: fortuna.conformal.classification.multicalibrator

.. _conformal_regression:

.. automodule:: fortuna.conformal.regression.quantile
@@ -33,3 +35,5 @@
.. automodule:: fortuna.conformal.regression.adaptive_conformal_regressor

.. automodule:: fortuna.conformal.regression.batch_mvp

.. automodule:: fortuna.conformal.regression.multicalibrator
63 changes: 38 additions & 25 deletions examples/multivalid_coverage.pct.py
@@ -206,48 +206,52 @@ def plot_intervals(xx, means, intervals, test_data, method):
# %% [markdown]
# We finally introduce Batch MVP [[Jung C. et al., 2022]](https://arxiv.org/pdf/2209.15145.pdf) and show that it improves group-conditional coverage. For its usage, we require:
#
# - non-conformity scores evaluated on calibration data. These can be evaluations of any score function measuring the degree of non-conformity between inputs $x$ and targets $y$. The less $x$ and $y$ conform with each other, the larger the score should be. A simple example of a score function in regression is $s(x,y)=|y - h(x)|$, where $h$ is an arbitrary model. For the purpose of this example, we use the same score function as in CQR, that is $s(x,y)=\max\left(q_{\frac{\alpha}{2}} - y, y - q_{1 - \frac{\alpha}{2}}\right)$, where $\alpha$ is the desired coverage error, i.e. $\alpha=0.05$, and $q_\alpha$ is a corresponding quantile.
#
# - group evaluations on calibration and test data. These are evaluations of group functions that identify sub-domains of interest of the input domain. As we defined above, here we use $g_1(x) = \mathbb{1}[x < 0]$ and $g_2(x) = \mathbb{1}[x \ge 0]$.
#
# That's it! Once these are defined, we are ready to run `BatchMVP`.

# %%
from fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor
import jax.numpy as jnp


qleft, qright = prob_model.predictive.quantile(
    [0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader
)
scores = jnp.maximum(qleft - calib_targets, calib_targets - qright).squeeze(1)
min_score, max_score = scores.min(), scores.max()
scores = (scores - min_score) / (max_score - min_score)
groups = jnp.stack([g(calib_data[0]) for g in group_fns], axis=1)
test_groups = jnp.stack([g(test_data[0]) for g in group_fns], axis=1)

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(
    scores=scores, groups=groups, test_groups=test_groups
)
test_thresholds = min_score + (max_score - min_score) * test_thresholds

# %% [markdown]
# At each iteration, `BatchMVP` computes the maximum calibration error over the different groups. We report its decay in the following picture.

# %%
plt.figure(figsize=(6, 3))
plt.plot(max_calib_errors, label="maximum calibration error decay")
plt.plot(status["max_calib_errors"], label="maximum calibration error decay")
plt.xlabel("rounds")
plt.legend()
plt.show()

# %% [markdown]
# Given the test thresholds, we can find the lower and upper bounds of the conformal intervals by inverting the score function $s(x, y)$ with respect to $y$. This gives $b(x, \tau) = [q_{\frac{\alpha}{2}}(x) - \tau, q_{1 - \frac{\alpha}{2}}(x) + \tau]$, where $\tau$ denotes the thresholds.

# %%
test_qleft, test_qright = prob_model.predictive.quantile(
[0.05 / 2, 1 - 0.05 / 2], test_inputs_loader
)
test_qleft, test_qright = test_qleft.squeeze(1), test_qright.squeeze(1)
test_batchmvp_intervals = jnp.stack(
(test_qleft - test_thresholds, test_qright + test_thresholds), axis=1
)

# %% [markdown]
# We now compute coverage metrics. As expected, `BatchMVP` not only provides a good marginal coverage overall, but also improves coverage on both negative and positive inputs.
@@ -268,5 +272,14 @@ def bounds_fn(x, t):
# Once again, we visualize predictions and estimated intervals.

# %%
xx_qleft, xx_qright = prob_model.predictive.quantile(
[0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx)
)
xx_qleft, xx_qright = xx_qleft.squeeze(1), xx_qright.squeeze(1)
xx_groups = jnp.stack([g(xx) for g in group_fns], axis=1)
xx_thresholds = batchmvp.apply_patches(groups=xx_groups)
xx_thresholds = min_score + (max_score - min_score) * xx_thresholds
xx_batchmvp_intervals = jnp.stack(
(xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1
)
plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP")
1 change: 1 addition & 0 deletions fortuna/conformal/__init__.py
@@ -9,6 +9,7 @@
from fortuna.conformal.classification.simple_prediction import (
SimplePredictionConformalClassifier,
)
from fortuna.conformal.multivalid.multicalibrator import Multicalibrator
from fortuna.conformal.regression.adaptive_conformal_regressor import (
AdaptiveConformalRegressor,
)