diff --git a/benchmarks/two_moons_multicalibrate.py b/benchmarks/two_moons_multicalibrate.py index fa5bb5cc..baf5205c 100644 --- a/benchmarks/two_moons_multicalibrate.py +++ b/benchmarks/two_moons_multicalibrate.py @@ -1,22 +1,26 @@ -from sklearn.datasets import make_moons -from fortuna.data import DataLoader +import flax.linen as nn import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +import optax +from sklearn.datasets import make_moons + from fortuna.conformal import Multicalibrator -from fortuna.prob_model import ProbClassifier, MAPPosteriorApproximator +from fortuna.data import ( + DataLoader, + InputsLoader, +) +from fortuna.metric.classification import accuracy from fortuna.model.mlp import MLP -import flax.linen as nn from fortuna.prob_model import ( + CalibConfig, + CalibMonitor, FitConfig, FitMonitor, FitOptimizer, - CalibConfig, - CalibMonitor, + MAPPosteriorApproximator, + ProbClassifier, ) -from fortuna.metric.classification import accuracy -import optax -import matplotlib.pyplot as plt -from fortuna.data import InputsLoader -import numpy as np train_data = make_moons(n_samples=5000, noise=0.07, random_state=0) val_data = make_moons(n_samples=1000, noise=0.07, random_state=1) @@ -85,7 +89,14 @@ test_groups = jnp.stack((test_means.argmax(1) == 0, test_means.argmax(1) == 1), axis=1) values = val_means[:, 1] test_values = test_means[:, 1] -calib_test_values, status = mc.calibrate(scores=scores, groups=groups, values=values, test_groups=test_groups, test_values=test_values, n_buckets=1000) +calib_test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + test_values=test_values, + n_buckets=1000, +) plt.figure(figsize=(10, 3)) plt.suptitle("Multivalid calibration of probability that Y=1") @@ -109,8 +120,31 @@ plt.semilogy(status["max_calib_errors"]) plt.show() -print("Per-group reweighted avg. squared calib. error before calibration: ", mc.calibration_error(scores=test_scores, groups=test_groups, values=test_means.max(1))) -print("Per-group reweighted avg. squared calib. error after calibration: ", mc.calibration_error(scores=test_scores, groups=test_groups, values=calib_test_values)) +print( + "Per-group reweighted avg. squared calib. error before calibration: ", + mc.calibration_error( + scores=test_scores, groups=test_groups, values=test_means.max(1) + ), +) +print( + "Per-group reweighted avg. squared calib. error after calibration: ", + mc.calibration_error( + scores=test_scores, groups=test_groups, values=calib_test_values + ), +) -print("Mismatch between labels and probs before calibration: ", jnp.mean(jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values)))) -print("Mismatch between labels and probs after calibration: ", jnp.mean(jnp.maximum((1 - test_targets) * calib_test_values, test_targets * (1 - calib_test_values)))) +print( + "Mismatch between labels and probs before calibration: ", + jnp.mean( + jnp.maximum((1 - test_targets) * test_values, test_targets * (1 - test_values)) + ), +) +print( + "Mismatch between labels and probs after calibration: ", + jnp.mean( + jnp.maximum( + (1 - test_targets) * calib_test_values, + test_targets * (1 - calib_test_values), + ) + ), +) diff --git a/examples/multivalid_coverage.pct.py b/examples/multivalid_coverage.pct.py index dff1d652..241d0b8e 100644 --- a/examples/multivalid_coverage.pct.py +++ b/examples/multivalid_coverage.pct.py @@ -216,7 +216,9 @@ def plot_intervals(xx, means, intervals, test_data, method): from fortuna.conformal.regression.batch_mvp import BatchMVPConformalRegressor import jax.numpy as jnp -qleft, qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader) +qleft, qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], calib_inputs_loader +) scores = jnp.maximum(qleft - calib_targets, calib_targets - qright).squeeze(1) min_score, max_score = scores.min(), scores.max() scores = (scores - min_score) / (max_score - min_score) @@ -224,7 +226,9 @@ def plot_intervals(xx, means, intervals, test_data, method): test_groups = jnp.stack([g(test_data[0]) for g in group_fns], axis=1) batchmvp = BatchMVPConformalRegressor() -test_thresholds, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups) +test_thresholds, status = batchmvp.calibrate( + scores=scores, groups=groups, test_groups=test_groups +) test_thresholds = min_score + (max_score - min_score) * test_thresholds # %% [markdown] @@ -241,9 +245,13 @@ def plot_intervals(xx, means, intervals, test_data, method): # Given the test thresholds, we can find the lower and upper bounds of the conformal intervals by inverting the score function $s(x, y)$ with respect to $y$. This gives $b(x, \tau) = [q_{\frac{\alpha}{2}}(x) - \tau, q_{1 - \frac{\alpha}{2}}(x) + \tau]$, where $\tau$ denotes the thresholds. # %% -test_qleft, test_qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], test_inputs_loader) +test_qleft, test_qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], test_inputs_loader +) test_qleft, test_qright = test_qleft.squeeze(1), test_qright.squeeze(1) -test_batchmvp_intervals = jnp.stack((test_qleft - test_thresholds, test_qright + test_thresholds), axis=1) +test_batchmvp_intervals = jnp.stack( + (test_qleft - test_thresholds, test_qright + test_thresholds), axis=1 +) # %% [markdown] # We now compute coverage metrics. As expected, `BatchMVP` not only provides a good marginal coverage overall, but also improves coverage on both negative and positive inputs. @@ -264,10 +272,14 @@ def plot_intervals(xx, means, intervals, test_data, method): # Once again, we visualize predictions and estimated intervals. # %% -xx_qleft, xx_qright = prob_model.predictive.quantile([0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx)) +xx_qleft, xx_qright = prob_model.predictive.quantile( + [0.05 / 2, 1 - 0.05 / 2], InputsLoader.from_array_inputs(xx) +) xx_qleft, xx_qright = xx_qleft.squeeze(1), xx_qright.squeeze(1) xx_groups = jnp.stack([g(xx) for g in group_fns], axis=1) xx_thresholds = batchmvp.apply_patches(groups=xx_groups) xx_thresholds = min_score + (max_score - min_score) * xx_thresholds -xx_batchmvp_intervals = jnp.stack((xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1) +xx_batchmvp_intervals = jnp.stack( + (xx_qleft - xx_thresholds, xx_qright + xx_thresholds), axis=1 +) plot_intervals(xx, xx_means, xx_batchmvp_intervals, test_data, "BatchMVP") diff --git a/fortuna/conformal/__init__.py b/fortuna/conformal/__init__.py index 79fa9a2a..3118c74b 100644 --- a/fortuna/conformal/__init__.py +++ b/fortuna/conformal/__init__.py @@ -5,11 +5,11 @@ from fortuna.conformal.classification.adaptive_prediction import ( AdaptivePredictionConformalClassifier, ) -from fortuna.conformal.multivalid.multicalibrator import Multicalibrator from fortuna.conformal.classification.batch_mvp import BatchMVPConformalClassifier from fortuna.conformal.classification.simple_prediction import ( SimplePredictionConformalClassifier, ) +from fortuna.conformal.multivalid.multicalibrator import Multicalibrator from fortuna.conformal.regression.adaptive_conformal_regressor import ( AdaptiveConformalRegressor, ) diff --git a/fortuna/conformal/classification/batch_mvp.py b/fortuna/conformal/classification/batch_mvp.py index 34d23766..52319b1c 100644 --- a/fortuna/conformal/classification/batch_mvp.py +++ b/fortuna/conformal/classification/batch_mvp.py @@ -1,6 +1,4 @@ -from typing import ( - List, -) +from typing import List import jax.numpy as jnp import numpy as np @@ -46,14 +44,18 @@ def conformal_set( Conformal sets for each input data point. """ if class_scores.ndim != 2: - raise ValueError("`class_scores` must bse a 2-dimensional array. " - "The first dimension is over the different inputs. " - "The second dimension is over all the possible classes.") + raise ValueError( + "`class_scores` must bse a 2-dimensional array. " + "The first dimension is over the different inputs. " + "The second dimension is over all the possible classes." + ) if values.ndim != 1: raise ValueError("`values` must be a 1-dimensional array.") if class_scores.shape[0] != values.shape[0]: - raise ValueError("The first dimension of `class_scores` and `values` must be over the same input data " - "points.") + raise ValueError( + "The first dimension of `class_scores` and `values` must be over the same input data " + "points." + ) bools = class_scores <= values[:, None] sizes = np.sum(bools, 0) diff --git a/fortuna/conformal/multivalid/base.py b/fortuna/conformal/multivalid/base.py index a07f8a5a..b1207170 100644 --- a/fortuna/conformal/multivalid/base.py +++ b/fortuna/conformal/multivalid/base.py @@ -2,11 +2,11 @@ import logging from typing import ( Callable, + Dict, List, Optional, Tuple, Union, - Dict ) from jax import vmap @@ -62,7 +62,7 @@ def calibrate( tol: float = 1e-4, n_buckets: int = None, n_rounds: int = 1000, - **kwargs + **kwargs, ) -> Union[Dict, Tuple[Array, Dict]]: """ Calibrate the model by finding a list of patches to the model that bring the calibration error below a @@ -104,14 +104,20 @@ def calibrate( if tol >= 1: raise ValueError("`tol` must be smaller than 1.") if n_rounds < 1: - raise ValueError("`n_rounds` must be at least 1.") + raise ValueError("`n_rounds` must be at least 1.") if test_groups is None and test_values is not None: - raise ValueError("If `test_values` is provided, `test_groups` must be also provided.") + raise ValueError( + "If `test_values` is provided, `test_groups` must be also provided." + ) if test_values is not None and values is None: - raise ValueError("If `test_values is provided, `values` must also be provided.") + raise ValueError( + "If `test_values is provided, `values` must also be provided." + ) if values is not None and test_groups is not None and test_values is None: - raise ValueError("If `values` and `test_groups` are provided, `test_values` must also be provided.") + raise ValueError( + "If `values` and `test_groups` are provided, `test_values` must also be provided." + ) if values is None: values_init = jnp.zeros(groups.shape[0]) @@ -155,9 +161,13 @@ def calibrate( max_calib_errors.append(calib_error_vg.sum(1).max()) if max_calib_errors[-1] <= tol: tol_reached = True - logging.info(f"Tolerance satisfied after {t} rounds with {n_buckets} buckets.") + logging.info( + f"Tolerance satisfied after {t} rounds with {n_buckets} buckets." + ) break - if old_calib_errors_vg is not None and jnp.allclose(old_calib_errors_vg, calib_error_vg): + if old_calib_errors_vg is not None and jnp.allclose( + old_calib_errors_vg, calib_error_vg + ): break old_calib_errors_vg = jnp.copy(calib_error_vg) @@ -168,7 +178,13 @@ def calibrate( groups=groups, values=values, v=vt, g=gt, n_buckets=len(buckets) ) patch = self._get_patch( - vt=vt, gt=gt, scores=scores, groups=groups, values=values, buckets=buckets, **kwargs + vt=vt, + gt=gt, + scores=scores, + groups=groups, + values=values, + buckets=buckets, + **kwargs, ) values = self._patch(values=values, patch=patch, bt=bt) @@ -221,7 +237,9 @@ def apply_patches( values = vmap(lambda v: self._round_to_buckets(v, buckets))(values) for gt, vt, patch in self._patches: - bt = self._get_b(groups=groups, values=values, v=vt, g=gt, n_buckets=self.n_buckets) + bt = self._get_b( + groups=groups, values=values, v=vt, g=gt, n_buckets=self.n_buckets + ) values = self._patch(values=values, bt=bt, patch=patch) return values @@ -231,7 +249,7 @@ def calibration_error( groups: Array, values: Array, n_buckets: int = 10000, - **kwargs + **kwargs, ) -> Array: """ The reweighted average squared calibration error :math:`\mu(g) K_2(f, g, \mathcal{D})`. @@ -259,15 +277,21 @@ def calibration_error( return vmap( lambda g: vmap( lambda v: self._calibration_error( - v=v, g=g, scores=scores, groups=groups, values=values, n_buckets=n_buckets, **kwargs + v=v, + g=g, + scores=scores, + groups=groups, + values=values, + n_buckets=n_buckets, + **kwargs, ) )(buckets) )(jnp.arange(groups.shape[1])).sum(1) - + @property def patches(self): return self._patches - + @property def n_buckets(self): return self._n_buckets @@ -285,7 +309,7 @@ def _calibration_error( groups: Array, values: Array, n_buckets: int, - **kwargs + **kwargs, ): pass @@ -318,7 +342,7 @@ def _get_patch( groups: Array, values: Array, buckets: Array, - **kwargs + **kwargs, ) -> Array: pass @@ -345,5 +369,6 @@ def _maybe_check(k, v): if v is not None: if v.min() < 0 or v.max() > 1: raise ValueError(f"All elements in `{k}` must be between 0 and 1.") + for k, v in d.items(): _maybe_check(k, v) diff --git a/fortuna/conformal/multivalid/batch_mvp.py b/fortuna/conformal/multivalid/batch_mvp.py index 9fd6aeb1..a16281ce 100644 --- a/fortuna/conformal/multivalid/batch_mvp.py +++ b/fortuna/conformal/multivalid/batch_mvp.py @@ -1,8 +1,8 @@ from typing import ( + Dict, Optional, Tuple, Union, - Dict ) from jax import vmap @@ -36,7 +36,7 @@ def calibrate( tol: float = 1e-4, n_buckets: int = None, n_rounds: int = 1000, - coverage: float = 0.95 + coverage: float = 0.95, ) -> Union[Dict, Tuple[Array, Dict]]: """ Calibrate the model by finding a list of patches to the model that bring the calibration error below a @@ -89,7 +89,7 @@ def calibrate( tol=tol, n_buckets=n_buckets, n_rounds=n_rounds, - coverage=coverage + coverage=coverage, ) def calibration_error( @@ -98,14 +98,14 @@ def calibration_error( groups: Array, values: Array, n_buckets: int = 10000, - **kwargs + **kwargs, ) -> Array: return super().calibration_error( scores=scores, groups=groups, values=values, n_buckets=n_buckets, - coverage=self._coverage + coverage=self._coverage, ) def _calibration_error( @@ -116,18 +116,18 @@ def _calibration_error( groups: Array, values: Array, n_buckets: int, - coverage: float = None + coverage: float = None, ): prob_error, prob_b = self._compute_probability_error( v=v, g=g, - delta=0., + delta=0.0, scores=scores, groups=groups, values=values, n_buckets=n_buckets, return_prob_b=True, - coverage=coverage + coverage=coverage, ) return prob_b * prob_error @@ -185,7 +185,7 @@ def _get_patch( groups: Array, values: Array, buckets: Array, - coverage: float = None + coverage: float = None, ) -> Array: return buckets[ jnp.argmin( @@ -198,13 +198,13 @@ def _get_patch( groups=groups, values=values, n_buckets=len(buckets), - coverage=coverage + coverage=coverage, ) - )( - buckets - ) + )(buckets) ) ] - def _patch(self, values: Array, patch: Array, bt: Array, _shift: bool = True) -> Array: + def _patch( + self, values: Array, patch: Array, bt: Array, _shift: bool = True + ) -> Array: return super()._patch(values=values, patch=patch, bt=bt, _shift=_shift) diff --git a/fortuna/conformal/multivalid/multicalibrator.py b/fortuna/conformal/multivalid/multicalibrator.py index 8bcb2763..882b28c3 100644 --- a/fortuna/conformal/multivalid/multicalibrator.py +++ b/fortuna/conformal/multivalid/multicalibrator.py @@ -23,7 +23,7 @@ def _calibration_error( groups: Array, values: Array, n_buckets: int, - **kwargs + **kwargs, ): expectation_error, prob_b = self._compute_expectation_error( v=v, @@ -95,7 +95,7 @@ def _get_patch( groups: Array, values: Array, buckets: Array, - **kwargs + **kwargs, ) -> Array: patch = self._compute_expectation( v=vt, diff --git a/fortuna/conformal/regression/batch_mvp.py b/fortuna/conformal/regression/batch_mvp.py index c27d06f5..ea0dd2e1 100644 --- a/fortuna/conformal/regression/batch_mvp.py +++ b/fortuna/conformal/regression/batch_mvp.py @@ -1,5 +1,5 @@ -from fortuna.conformal.regression.base import ConformalRegressor from fortuna.conformal.multivalid.batch_mvp import BatchMVPConformalMethod +from fortuna.conformal.regression.base import ConformalRegressor class BatchMVPConformalRegressor(BatchMVPConformalMethod, ConformalRegressor): diff --git a/tests/fortuna/test_conformal_methods.py b/tests/fortuna/test_conformal_methods.py index 42bb79b0..b75b7ff6 100755 --- a/tests/fortuna/test_conformal_methods.py +++ b/tests/fortuna/test_conformal_methods.py @@ -1,5 +1,6 @@ import unittest +from jax import random import jax.numpy as jnp import numpy as np @@ -17,12 +18,11 @@ QuantileConformalRegressor, SimplePredictionConformalClassifier, ) +from fortuna.conformal.multivalid.multicalibrator import Multicalibrator from fortuna.data.loader import ( DataLoader, InputsLoader, ) -from fortuna.conformal.multivalid.multicalibrator import Multicalibrator -from jax import random class TestConformalMethods(unittest.TestCase): @@ -328,19 +328,54 @@ def test_batchmvp_regressor(self): test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) test_values = jnp.zeros(test_size) batchmvp = BatchMVPConformalRegressor() - status = batchmvp.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) - status = batchmvp.calibrate(scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4) - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups, n_rounds=3, n_buckets=4) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 + ) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, values=values, test_groups=test_groups, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, values=values, test_values=test_values, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups, test_values=test_values, n_rounds=3, n_buckets=4) - status = batchmvp.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) test_values = batchmvp.apply_patches(test_groups) test_values = batchmvp.apply_patches(test_groups, test_values) - error = batchmvp.calibration_error(scores=test_scores, groups=test_groups, values=test_values) + error = batchmvp.calibration_error( + scores=test_scores, groups=test_groups, values=test_values + ) def test_batchmvp_classifier(self): size = 10 @@ -351,21 +386,59 @@ def test_batchmvp_classifier(self): test_scores = random.uniform(random.PRNGKey(0), shape=(test_size,)) test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) batchmvp = BatchMVPConformalClassifier() - status = batchmvp.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) - status = batchmvp.calibrate(scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4) - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups, n_rounds=3, n_buckets=4) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 + ) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, values=values, test_groups=test_groups, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, values=values, test_values=test_values, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = batchmvp.calibrate(scores=scores, groups=groups, test_groups=test_groups, test_values=test_values, n_rounds=3, n_buckets=4) - status = batchmvp.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) + test_values, status = batchmvp.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) + status = batchmvp.calibrate( + scores=scores, groups=groups, n_rounds=3, n_buckets=4 + ) test_values = batchmvp.apply_patches(test_groups) test_values = batchmvp.apply_patches(test_groups, test_values) - error = batchmvp.calibration_error(scores=test_scores, groups=test_groups, values=test_values) + error = batchmvp.calibration_error( + scores=test_scores, groups=test_groups, values=test_values + ) - sets = batchmvp.conformal_set(class_scores=jnp.stack((test_scores, test_scores), axis=1), values=test_values) + sets = batchmvp.conformal_set( + class_scores=jnp.stack((test_scores, test_scores), axis=1), + values=test_values, + ) assert len(sets) == test_size def test_multicalibrator(self): @@ -378,15 +451,46 @@ def test_multicalibrator(self): test_groups = random.choice(random.PRNGKey(1), 2, shape=(test_size, 3)) mc = Multicalibrator() status = mc.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) - status = mc.calibrate(scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4) - test_values, status = mc.calibrate(scores=scores, groups=groups, test_groups=test_groups, n_rounds=3, n_buckets=4) + status = mc.calibrate( + scores=scores, groups=groups, values=values, n_rounds=3, n_buckets=4 + ) + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = mc.calibrate(scores=scores, groups=groups, values=values, test_groups=test_groups, n_rounds=3, n_buckets=4) + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_groups=test_groups, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = mc.calibrate(scores=scores, groups=groups, values=values, test_values=test_values, n_rounds=3, n_buckets=4) + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + values=values, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) with self.assertRaises(ValueError): - test_values, status = mc.calibrate(scores=scores, groups=groups, test_groups=test_groups, test_values=test_values, n_rounds=3, n_buckets=4) + test_values, status = mc.calibrate( + scores=scores, + groups=groups, + test_groups=test_groups, + test_values=test_values, + n_rounds=3, + n_buckets=4, + ) status = mc.calibrate(scores=scores, groups=groups, n_rounds=3, n_buckets=4) test_values = mc.apply_patches(test_groups) test_values = mc.apply_patches(test_groups, test_values) - error = mc.calibration_error(scores=test_scores, groups=test_groups, values=test_values) + error = mc.calibration_error( + scores=test_scores, groups=test_groups, values=test_values + )