Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CV+ Adaptive & Simple Prediction Conformal Classifier #136

Merged
merged 56 commits into from
Sep 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
52e96ea
edit installation instructions in readme
gianlucadetommaso May 15, 2023
5e0076d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
4c7fd28
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 15, 2023
6cb6581
bump up version
gianlucadetommaso May 15, 2023
1b39780
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
cb2b49a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 16, 2023
14e3ca4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 25, 2023
580067d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso May 27, 2023
048ef09
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 2, 2023
ad542a4
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
41417c1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 12, 2023
64be374
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
a2d0f34
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 14, 2023
66bba06
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
911aa82
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
01f959b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
79f8dca
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 15, 2023
4dea50f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jun 21, 2023
1ced008
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6992692
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
b2540c1
make small change in readme because of publish to pypi error
gianlucadetommaso Jul 18, 2023
2362998
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 18, 2023
6e030f2
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
9bd6f67
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
c5bc94f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 25, 2023
d3ab46b
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
0e2aca5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 26, 2023
9520273
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
e9c4108
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
bc64a01
bump up version
gianlucadetommaso Jul 30, 2023
25072da
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
e27b378
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Jul 30, 2023
a175e16
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 1, 2023
6e202f1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 1, 2023
635e7c9
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 9, 2023
8e23b32
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 16, 2023
f5efef8
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 24, 2023
958b245
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 24, 2023
577d169
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 28, 2023
69a454e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 30, 2023
6e880ba
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Aug 30, 2023
f606545
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 11, 2023
63e09bb
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 11, 2023
b2402b5
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 12, 2023
591d842
refactor tabular analysis of benchmarks
gianlucadetommaso Sep 13, 2023
3dcf217
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 13, 2023
d1b5b4a
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 18, 2023
b4c161e
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 21, 2023
744dff1
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 21, 2023
a22f97f
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 24, 2023
fffdd76
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 26, 2023
c23d16d
Merge branch 'main' of https://github.com/awslabs/fortuna
gianlucadetommaso Sep 26, 2023
411a88b
add CV+ Simple & Adaptive Prediction Conformal Classifiers
gianlucadetommaso Sep 27, 2023
270eb52
pre-commit and bump up version
gianlucadetommaso Sep 27, 2023
3a3f8a5
make error explicit in notebook
gianlucadetommaso Sep 27, 2023
acf08e7
add error explitly in doc and notebooks
gianlucadetommaso Sep 27, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion docs/source/usage_modes/flax_models.rst
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,8 @@ but a new one could be used.
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
val_probs=calib_means,
test_probs=test_means,
val_targets=calib_targets
val_targets=calib_targets,
error=0.05
)

.. _flax_models_regression:
Expand Down
3 changes: 2 additions & 1 deletion docs/source/usage_modes/model_outputs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,8 @@ and :code:`val_targets` to be the corresponding validation target variables.
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
val_probs=val_means,
test_probs=test_means,
val_targets=val_targets
val_targets=val_targets,
error=0.05
)

.. _model_outputs_regression:
Expand Down
3 changes: 2 additions & 1 deletion docs/source/usage_modes/uncertainty_estimates.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ Please check :class:`~fortuna.conformal.classification.adaptive_prediction.Adapt
conformal_sets = AdaptivePredictionConformalClassifier().conformal_set(
val_probs=val_probs,
test_probs=test_probs,
val_targets=val_targets
val_targets=val_targets,
error=0.05
)

You should usually expect your test predictions to be included in the conformal sets, as they contain the most probable
Expand Down
1 change: 1 addition & 0 deletions examples/mnist_classification.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def download(split_range, shuffle=False):
val_probs=val_means,
test_probs=test_means,
val_targets=val_data_loader.to_array_targets(),
error=0.05
)

# %% [markdown]
Expand Down
1 change: 1 addition & 0 deletions examples/mnist_classification_sghmc.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def download(split_range, shuffle=False):
val_probs=val_means,
test_probs=test_means,
val_targets=val_data_loader.to_array_targets(),
error=0.05
)

# %% [markdown]
Expand Down
2 changes: 2 additions & 0 deletions fortuna/conformal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
)
from fortuna.conformal.classification.adaptive_prediction import (
AdaptivePredictionConformalClassifier,
CVPlusAdaptivePredictionConformalClassifier,
)
from fortuna.conformal.classification.maxcovfixprec_binary_classfication import (
MaxCoverageFixedPrecisionBinaryClassificationCalibrator,
)
from fortuna.conformal.classification.simple_prediction import (
CVPlusSimplePredictionConformalClassifier,
SimplePredictionConformalClassifier,
)
from fortuna.conformal.multivalid.iterative.classification.binary_multicalibrator import (
Expand Down
146 changes: 26 additions & 120 deletions fortuna/conformal/classification/adaptive_prediction.py
Original file line number Diff line number Diff line change
@@ -1,134 +1,40 @@
from typing import (
List,
Optional,
)

from jax import vmap
import jax.numpy as jnp
import numpy as np

from fortuna.conformal.classification.base import ConformalClassifier
from fortuna.conformal.classification.base import (
CVPlusConformalClassifier,
SplitConformalClassifier,
)
from fortuna.typing import Array


class AdaptivePredictionConformalClassifier(ConformalClassifier):
def score(
self,
val_probs: Array,
val_targets: Array,
) -> jnp.ndarray:
"""
Compute score function.

Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
val_targets: Array
A one-dimensional array of validation target variables.

Returns
-------
jnp.ndarray
The conformal scores.
"""
if val_probs.ndim != 2:
raise ValueError(
"""`val_probs` must be a two-dimensional array. The first dimension is over the validation
inputs. The second is over the classes."""
)
@vmap
def _score_fn(probs: Array, perm: Array, inv_perm: Array, targets: Array):
return jnp.cumsum(probs[perm])[inv_perm[targets]]

perms = jnp.argsort(val_probs, axis=1)[:, ::-1]
inv_perms = jnp.argsort(perms, axis=1)

@vmap
def score_fn(prob, perm, inv_perm, target):
sorted_prob = prob[perm]
return jnp.cumsum(sorted_prob)[inv_perm[target]]
def score_fn(
probs: Array,
targets: Array,
):
perms = jnp.argsort(probs, axis=1)[:, ::-1]
inv_perms = jnp.argsort(perms, axis=1)
return _score_fn(probs, perms, inv_perms, targets)

return score_fn(val_probs, perms, inv_perms, val_targets)

def quantile(
class AdaptivePredictionConformalClassifier(SplitConformalClassifier):
def score_fn(
self,
val_probs: Array,
val_targets: Array,
error: float = 0.05,
scores: Optional[Array] = None,
) -> Array:
"""
Compute a quantile of the scores.
probs: Array,
targets: Array,
):
return score_fn(probs=probs, targets=targets)

Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
val_targets: Array
A one-dimensional array of validation target variables.
error: float
Coverage error. This must be a scalar between 0 and 1, extremes included.
scores: Optional[Array]
The conformal scores. This should be the output of
:meth:`~fortuna.conformal.classification.adaptive_prediction.AdaptivePredictionConformalClassifier.score`.

Returns
-------
float
The conformal quantiles.
"""
if error < 0 or error > 1:
raise ValueError("""`error` must be a scalar between 0 and 1.""")

if scores is None:
scores = self.score(val_probs, val_targets)
n = scores.shape[0]
return jnp.quantile(scores, jnp.ceil((n + 1) * (1 - error)) / n)

def conformal_set(
class CVPlusAdaptivePredictionConformalClassifier(CVPlusConformalClassifier):
def score_fn(
self,
val_probs: Array,
test_probs: Array,
val_targets: Array,
error: float = 0.05,
quantile: Optional[float] = None,
) -> List[List[int]]:
"""
Coverage set of each of the test inputs, at the desired coverage error.

Parameters
----------
val_probs: Array
A two-dimensional array of class probabilities for each validation data point.
test_probs: Array
A two-dimensional array of class probabilities for each test data point.
val_targets: Array
A one-dimensional array of validation target variables.
error: float
The coverage error. This must be a scalar between 0 and 1, extremes included.
quantile: Optional[float]
Conformal quantiles. This should be the output of
:meth:`~fortuna.conformal.classification.adaptive_prediction.AdaptivePredictionConformalClassifier.quantile`.

Returns
-------
List[List[int, ...]]
The coverage sets.
"""
if test_probs.ndim != 2:
raise ValueError(
"""`test_probs` must be a two-dimensional array. The first dimension is over the validation
inputs. The second is over the classes."""
)

if quantile is None:
quantile = self.quantile(val_probs, val_targets, error)
test_perms = jnp.argsort(test_probs, axis=1)[:, ::-1]
test_sorted_probs = vmap(lambda prob, perm: prob[perm])(test_probs, test_perms)
sizes = (
(test_sorted_probs.cumsum(axis=1) > quantile).astype("int32").argmax(axis=1)
)

sets = np.zeros(len(sizes), dtype=object)
for s in jnp.unique(sizes):
idx = jnp.where(sizes == s)[0]
sets[idx] = test_perms[idx, : s + 1].tolist()
return sets.tolist()
probs: Array,
targets: Array,
):
return score_fn(probs=probs, targets=targets)
105 changes: 104 additions & 1 deletion fortuna/conformal/classification/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from typing import List
import abc
from typing import (
List,
Tuple,
)

from jax import vmap
import jax.numpy as jnp
import numpy as np

from fortuna.typing import Array

Expand All @@ -27,3 +33,100 @@ def is_in(self, values: Array, conformal_sets: List) -> Array:
An array of ones or zero, indicating whether the values lie within their respective conformal sets.
"""
return jnp.array([v in s for v, s in zip(values.tolist(), conformal_sets)])

@abc.abstractmethod
def score_fn(
self,
probs: Array,
targets: Array,
):
pass

@staticmethod
def _get_conformal_sets_from_scores(
val_scores: Array,
test_scores: Array,
error: float,
) -> List[List[int]]:
conds = jnp.sum(val_scores[:, None, None] > test_scores[None], axis=0) < (
1 - error
) * (len(val_scores) + 1)
sizes = conds.sum(1)

sets = np.zeros(len(test_scores), dtype=object)
for us in jnp.unique(sizes):
idx = jnp.where(sizes == us)[0]
if us == 0:
sets[idx] = [len(idx) * []]
else:
sets[idx] = np.where(conds[idx])[1].reshape(-1, us).tolist()

return sets.tolist()

@abc.abstractmethod
def get_scores(self, *args, **kwargs) -> Tuple[Array, Array]:
pass


class SplitConformalClassifier(ConformalClassifier, abc.ABC):
def get_scores(
self, val_probs: Array, val_targets: Array, test_probs: Array
) -> Tuple[Array, Array]:
val_scores = self.score_fn(val_probs, val_targets)
test_scores = vmap(
lambda i: self.score_fn(
test_probs, i * jnp.ones(len(test_probs), dtype="int32")
),
out_axes=1,
)(jnp.arange(val_probs.shape[1]))
return val_scores, test_scores

def conformal_set(
self, val_probs: Array, val_targets: Array, test_probs: Array, error: float
) -> List[List[int]]:
val_scores, test_scores = self.get_scores(
val_probs=val_probs, val_targets=val_targets, test_probs=test_probs
)
return super()._get_conformal_sets_from_scores(
val_scores=val_scores, test_scores=test_scores, error=error
)


class CVPlusConformalClassifier(ConformalClassifier):
def conformal_set(
self,
cross_val_probs: List[Array],
cross_val_targets: List[Array],
cross_test_probs: List[Array],
error: float,
) -> List[List[int]]:
val_scores, test_scores = self.get_scores(
cross_val_probs=cross_val_probs,
cross_val_targets=cross_val_targets,
cross_test_probs=cross_test_probs,
)
return super()._get_conformal_sets_from_scores(
val_scores=val_scores, test_scores=test_scores, error=error
)

def get_scores(
self,
cross_val_probs: List[Array],
cross_val_targets: List[Array],
cross_test_probs: List[Array],
) -> Tuple[Array, Array]:
val_scores, test_scores = [], []
for val_probs, val_targets, test_probs in zip(
cross_val_probs, cross_val_targets, cross_test_probs
):
val_scores.append(self.score_fn(val_probs, val_targets))
test_scores.append(
vmap(
lambda i: self.score_fn(
test_probs, i * jnp.ones(len(test_probs), dtype="int32")
),
out_axes=1,
)(jnp.arange(cross_val_probs[0].shape[1]))
)

return jnp.concatenate(val_scores), jnp.concatenate(test_scores, axis=0)
Loading
Loading