Skip to content

Commit

Permalink
scale calib error and remove rounding
Browse files Browse the repository at this point in the history
  • Loading branch information
gianlucadetommaso committed Oct 27, 2023
1 parent c8ca7e6 commit c294a00
Show file tree
Hide file tree
Showing 8 changed files with 7 additions and 25 deletions.
2 changes: 1 addition & 1 deletion examples/multivalid_coverage.pct.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import numpy as np


def generate_data(n_data: int, sigma1=0.03, sigma2=0.5, seed: int = 43):
def generate_data(n_data: int, sigma1=0.03, sigma2=0.5, seed: int = 42):
rng = np.random.default_rng(seed=seed)
x = np.concatenate(
[
Expand Down
4 changes: 0 additions & 4 deletions fortuna/conformal/multivalid/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ def n_buckets(self, n_buckets):
def _get_buckets(n_buckets: int):
return jnp.linspace(0, 1, n_buckets)

@staticmethod
def _round_to_buckets(v: Array, buckets: Array):
return buckets[jnp.argmin(jnp.abs(v - buckets))]

@staticmethod
def _maybe_check_values(
values: Optional[Array], test_values: Optional[Array] = None
Expand Down
8 changes: 2 additions & 6 deletions fortuna/conformal/multivalid/iterative/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@ def calibrate(

self.n_buckets = n_buckets
buckets = self._get_buckets(n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values)

self._check_bucket_types(bucket_types)
taus = self._get_bucket_type_indices(bucket_types)
Expand Down Expand Up @@ -327,9 +326,6 @@ def apply_patches(
groups = self._init_groups(groups, values.shape[0])
self._maybe_check_groups(groups)

buckets = self._get_buckets(n_buckets=self.n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values)

for taut, gt, vt, ct, patch in self._patches:
bt = self._get_b(
groups=groups,
Expand All @@ -348,7 +344,7 @@ def calibration_error(
scores: Array,
groups: Optional[Array] = None,
values: Optional[Array] = None,
n_buckets: int = 10000,
n_buckets: int = 10,
**kwargs,
) -> Array:
"""
Expand Down Expand Up @@ -383,7 +379,6 @@ def calibration_error(
values = self._maybe_init_values(values, scores.shape[0])

buckets = self._get_buckets(n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values)

error, b = vmap(
lambda g: vmap(
Expand All @@ -403,6 +398,7 @@ def calibration_error(
)(buckets)
)(jnp.arange(groups.shape[1]))
error = error.sum(1)
error /= groups.mean(0)[:, None]

if n_dims == 1:
error = error.squeeze(1)
Expand Down
2 changes: 1 addition & 1 deletion fortuna/conformal/multivalid/iterative/batch_mvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def calibration_error(
scores: Array,
groups: Optional[Array] = None,
thresholds: Optional[Array] = None,
n_buckets: int = 10000,
n_buckets: int = 10,
**kwargs,
) -> Array:
return super().calibration_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def calibration_error(
targets: Array,
groups: Optional[Array] = None,
probs: Optional[Array] = None,
n_buckets: int = 10000,
n_buckets: int = 10,
**kwargs,
) -> Array:
return super().calibration_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def calibration_error(
targets: Array,
groups: Optional[Array] = None,
probs: Optional[Array] = None,
n_buckets: int = 10000,
n_buckets: int = 10,
**kwargs,
) -> Array:
return super().calibration_error(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,6 @@ def mean_squared_error(self, probs: Array, targets: Array) -> Array:
values=probs, scores=self._get_scores(targets)
)

@staticmethod
def _round_to_buckets(v: Array, buckets: Array) -> Array:
def _fun(_v):
return buckets[jnp.argmin(jnp.abs(_v - buckets))]

if len(v.shape):
return vmap(_fun)(v)
return _fun(v)

def _get_scores(self, targets: Array) -> Array:
self._check_targets(targets)
scores = []
Expand Down
3 changes: 1 addition & 2 deletions fortuna/conformal/multivalid/one_shot/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def apply_patches(
return values

buckets = self._get_buckets(self.n_buckets)
values = vmap(lambda v: self._round_to_buckets(v, buckets))(values)
unique_values = jnp.unique(values)

n_dims = 1 if values.ndim == 1 else values.shape[1]
Expand Down Expand Up @@ -138,7 +137,7 @@ def _patch(
patched_values = jnp.copy(values)
for i, v in enumerate(unique_values):
for c in range(b.shape[2]):
idx_v = jnp.where(buckets == v)[0][0]
idx_v = jnp.argmin(jnp.abs(buckets - v))
if values.ndim == 1:
patched_values = patched_values.at[b[:, i, c]].set(
self._patches[idx_v, c]
Expand Down

0 comments on commit c294a00

Please sign in to comment.