Skip to content

Commit

Permalink
black code
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Jul 30, 2023
1 parent 89c5cf9 commit 1963ac3
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 89 deletions.
66 changes: 50 additions & 16 deletions benchmarks/two_moons_multicalibrate.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from sklearn.datasets import make_moons
from fortuna.data import DataLoader
import flax.linen as nn
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from sklearn.datasets import make_moons

from fortuna.conformal import Multicalibrator
from fortuna.prob_model import ProbClassifier, MAPPosteriorApproximator
from fortuna.data import (
DataLoader,
InputsLoader,
)
from fortuna.metric.classification import accuracy
from fortuna.model.mlp import MLP
import flax.linen as nn
from fortuna.prob_model import (
CalibConfig,
CalibMonitor,
FitConfig,
FitMonitor,
FitOptimizer,
CalibConfig,
CalibMonitor,
MAPPosteriorApproximator,
ProbClassifier,
)
from fortuna.metric.classification import accuracy
import optax
import matplotlib.pyplot as plt
from fortuna.data import InputsLoader
import numpy as np

train_data = make_moons(n_samples=5000, noise=0.07, random_state=0)
val_data = make_moons(n_samples=1000, noise=0.07, random_state=1)
Expand Down Expand Up @@ -85,7 +89,14 @@
test_groups = jnp.stack((test_means.argmax(1) == 0, test_means.argmax(1) == 1), axis=1)
values = val_means[:, 1]
test_values = test_means[:, 1]
calib_test_values, status = mc.calibrate(scores=scores, groups=groups, values=values, test_groups=test_groups, test_values=test_values, n_buckets=1000)
calib_test_values, status = mc.calibrate(
scores=scores,
groups=groups,
values=values,
test_groups=test_groups,
test_values=test_values,
n_buckets=1000,
)

plt.figure(figsize=(10, 3))
plt.suptitle("Multivalid calibration of probability that Y=1")
Expand All @@ -109,8 +120,31 @@
plt.semilogy(status["max_calib_errors"])
plt.show()

print("Per-group reweighted avg. squared calib. error before calibration: ", mc.calibration_error(scores=test_scores, groups=test_groups, values=test_means.max(1)))
print("Per-group reweighted avg. squared calib. error after calibration: ", mc.calibration_error(scores=test_scores, groups=test_groups, values=calib_test_values))
print(
"Per-group reweighted avg. squared calib. error before calibration: ",
mc.calibration_error(
scores=test_scores, groups=test_groups, values=test_means.max(1)
),
)
print(
"Per-group reweighted avg. squared calib. error after calibration: ",
mc.calibration_error(
scores=test_scores, groups=test_groups, values=calib_test_values
),
)

print("Mismatch between labels and probs before calibration: ", jnp.mean(jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))))
print("Mismatch between labels and probs after calibration: ", jnp.mean(jnp.maximum((1 - test_targets) * calib_test_values, test_targets * (1 - calib_test_values))))
print(
"Mismatch between labels and probs before calibration: ",
jnp.mean(
jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values))
),
)
print(
"Mismatch between labels and probs after calibration: ",
jnp.mean(
jnp.maximum(
(1 - test_targets) * calib_test_values,
test_targets * (1 - calib_test_values),
)
),
)
24 changes: 18 additions & 6 deletions examples/multivalid_coverage.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,15 +216,19 @@ def plot_intervals(xx, means, intervals, test_data, method):
from fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor
import jax.numpy as jnp

qleft, qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader)
qleft, qright = prob_model.predictive.quantile(
[0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader
)
scores = jnp.maximum(qleft - calib_targets, calib_targets - qright).squeeze(1)
min_score, max_score = scores.min(), scores.max()
scores = (scores - min_score) / (max_score - min_score)
groups = jnp.stack([g(calib_data[0]) for g in group_fns], axis=1)
test_groups = jnp.stack([g(test_data[0]) for g in group_fns], axis=1)

batchmvp = BatchMVPConformalRegressor()
test_thresholds, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups)
test_thresholds, status = batchmvp.calibrate(
scores=scores, groups=groups, test_groups=test_groups
)
test_thresholds = min_score + (max_score - min_score) * test_thresholds

# %% [markdown]
Expand All @@ -241,9 +245,13 @@ def plot_intervals(xx, means, intervals, test_data, method):
# Given the test thresholds, we can find the lower and upper bounds of the conformal intervals by inverting the score function $s(x, y)$ with respect to $y$. This gives $b(x, \tau) = [q_{\frac{\alpha}{2}}(x) - \tau, q_{1 - \frac{\alpha}{2}}(x) + \tau]$, where $\tau$ denotes the thresholds.

# %%
test_qleft, test_qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], test_inputs_loader)
test_qleft, test_qright = prob_model.predictive.quantile(
[0.05 / 2, 1 - 0.05 / 2], test_inputs_loader
)
test_qleft, test_qright = test_qleft.squeeze(1), test_qright.squeeze(1)
test_batchmvp_intervals = jnp.stack((test_qleft - test_thresholds, test_qright + test_thresholds), axis=1)
test_batchmvp_intervals = jnp.stack(
(test_qleft - test_thresholds, test_qright + test_thresholds), axis=1
)

# %% [markdown]
# We now compute coverage metrics. As expected, `BatchMVP` not only provides good marginal coverage overall, but also improves coverage on both negative and positive inputs.
Expand All @@ -264,10 +272,14 @@ def plot_intervals(xx, means, intervals, test_data, method):
# Once again, we visualize predictions and estimated intervals.

# %%
xx_qleft, xx_qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx))
xx_qleft, xx_qright = prob_model.predictive.quantile(
[0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx)
)
xx_qleft, xx_qright = xx_qleft.squeeze(1), xx_qright.squeeze(1)
xx_groups = jnp.stack([g(xx) for g in group_fns], axis=1)
xx_thresholds = batchmvp.apply_patches(groups=xx_groups)
xx_thresholds = min_score + (max_score - min_score) * xx_thresholds
xx_batchmvp_intervals = jnp.stack((xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1)
xx_batchmvp_intervals = jnp.stack(
(xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1
)
plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP")
2 changes: 1 addition & 1 deletion fortuna/conformal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from fortuna.conformal.classification.adaptive_prediction import (
AdaptivePredictionConformalClassifier,
)
from fortuna.conformal.multivalid.multicalibrator import Multicalibrator
from fortuna.conformal.classification.batch_mvp import BatchMVPConformalClassifier
from fortuna.conformal.classification.simple_prediction import (
SimplePredictionConformalClassifier,
)
from fortuna.conformal.multivalid.multicalibrator import Multicalibrator
from fortuna.conformal.regression.adaptive_conformal_regressor import (
AdaptiveConformalRegressor,
)
Expand Down
18 changes: 10 additions & 8 deletions fortuna/conformal/classification/batch_mvp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import (
List,
)
from typing import List

import jax.numpy as jnp
import numpy as np
Expand Down Expand Up @@ -46,14 +44,18 @@ def conformal_set(
Conformal sets for each input data point.
"""
if class_scores.ndim != 2:
raise ValueError("`class_scores` must bse a 2-dimensional array. "
"The first dimension is over the different inputs. "
"The second dimension is over all the possible classes.")
raise ValueError(
"`class_scores` must bse a 2-dimensional array. "
"The first dimension is over the different inputs. "
"The second dimension is over all the possible classes."
)
if values.ndim != 1:
raise ValueError("`values` must be a 1-dimensional array.")
if class_scores.shape[0] != values.shape[0]:
raise ValueError("The first dimension of `class_scores` and `values` must be over the same input data "
"points.")
raise ValueError(
"The first dimension of `class_scores` and `values` must be over the same input data "
"points."
)
bools = class_scores <= values[:, None]

sizes = np.sum(bools, 0)
Expand Down
57 changes: 41 additions & 16 deletions fortuna/conformal/multivalid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
import logging
from typing import (
Callable,
Dict,
List,
Optional,
Tuple,
Union,
Dict
)

from jax import vmap
Expand Down Expand Up @@ -62,7 +62,7 @@ def calibrate(
tol: float = 1e-4,
n_buckets: int = None,
n_rounds: int = 1000,
**kwargs
**kwargs,
) -> Union[Dict, Tuple[Array, Dict]]:
"""
Calibrate the model by finding a list of patches to the model that bring the calibration error below a
Expand Down Expand Up @@ -104,14 +104,20 @@ def calibrate(
if tol >= 1:
raise ValueError("`tol` must be smaller than 1.")
if n_rounds < 1:
raise ValueError("`n_rounds` must be at least 1.")
raise ValueError("`n_rounds` must be at least 1.")

if test_groups is None and test_values is not None:
raise ValueError("If `test_values` is provided, `test_groups` must be also provided.")
raise ValueError(
"If `test_values` is provided, `test_groups` must be also provided."
)
if test_values is not None and values is None:
raise ValueError("If `test_values is provided, `values` must also be provided.")
raise ValueError(
"If `test_values is provided, `values` must also be provided."
)
if values is not None and test_groups is not None and test_values is None:
raise ValueError("If `values` and `test_groups` are provided, `test_values` must also be provided.")
raise ValueError(
"If `values` and `test_groups` are provided, `test_values` must also be provided."
)

if values is None:
values_init = jnp.zeros(groups.shape[0])
Expand Down Expand Up @@ -155,9 +161,13 @@ def calibrate(
max_calib_errors.append(calib_error_vg.sum(1).max())
if max_calib_errors[-1] <= tol:
tol_reached = True
logging.info(f"Tolerance satisfied after {t} rounds with {n_buckets} buckets.")
logging.info(
f"Tolerance satisfied after {t} rounds with {n_buckets} buckets."
)
break
if old_calib_errors_vg is not None and jnp.allclose(old_calib_errors_vg, calib_error_vg):
if old_calib_errors_vg is not None and jnp.allclose(
old_calib_errors_vg, calib_error_vg
):
break
old_calib_errors_vg = jnp.copy(calib_error_vg)

Expand All @@ -168,7 +178,13 @@ def calibrate(
groups=groups, values=values, v=vt, g=gt, n_buckets=len(buckets)
)
patch = self._get_patch(
vt=vt, gt=gt, scores=scores, groups=groups, values=values, buckets=buckets, **kwargs
vt=vt,
gt=gt,
scores=scores,
groups=groups,
values=values,
buckets=buckets,
**kwargs,
)
values = self._patch(values=values, patch=patch, bt=bt)

Expand Down Expand Up @@ -221,7 +237,9 @@ def apply_patches(
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values)

for gt, vt, patch in self._patches:
bt = self._get_b(groups=groups, values=values, v=vt, g=gt, n_buckets=self.n_buckets)
bt = self._get_b(
groups=groups, values=values, v=vt, g=gt, n_buckets=self.n_buckets
)
values = self._patch(values=values, bt=bt, patch=patch)
return values

Expand All @@ -231,7 +249,7 @@ def calibration_error(
groups: Array,
values: Array,
n_buckets: int = 10000,
**kwargs
**kwargs,
) -> Array:
"""
The reweighted average squared calibration error :math:`\mu(g) K_2(f, g, \mathcal{D})`.
Expand Down Expand Up @@ -259,15 +277,21 @@ def calibration_error(
return vmap(
lambda g: vmap(
lambda v: self._calibration_error(
v=v, g=g, scores=scores, groups=groups, values=values, n_buckets=n_buckets, **kwargs
v=v,
g=g,
scores=scores,
groups=groups,
values=values,
n_buckets=n_buckets,
**kwargs,
)
)(buckets)
)(jnp.arange(groups.shape[1])).sum(1)

@property
def patches(self):
return self._patches

@property
def n_buckets(self):
return self._n_buckets
Expand All @@ -285,7 +309,7 @@ def _calibration_error(
groups: Array,
values: Array,
n_buckets: int,
**kwargs
**kwargs,
):
pass

Expand Down Expand Up @@ -318,7 +342,7 @@ def _get_patch(
groups: Array,
values: Array,
buckets: Array,
**kwargs
**kwargs,
) -> Array:
pass

Expand All @@ -345,5 +369,6 @@ def _maybe_check(k, v):
if v is not None:
if v.min() < 0 or v.max() > 1:
raise ValueError(f"All elements in `{k}` must be between 0 and 1.")

for k, v in d.items():
_maybe_check(k, v)
Loading

0 comments on commit 1963ac3

Please sign in to comment.