Skip to content

Commit

Permalink
Add benchmarks for classification and regression tabular datasets (#118)
Browse files Browse the repository at this point in the history
* edit installation instructions in readme

* bump up version

* make small change in readme because of publish to pypi error

* bump up version

* tabular benchmarks

* refactor tabular run

* add tabular benchmarks

* pre-commit
  • Loading branch information
gianlucadetommaso authored Aug 16, 2023
1 parent c51b646 commit 36a0737
Show file tree
Hide file tree
Showing 8 changed files with 2,288 additions and 8 deletions.
798 changes: 798 additions & 0 deletions benchmarks/tabular/analysis.py

Large diffs are not rendered by default.

907 changes: 907 additions & 0 deletions benchmarks/tabular/dataset.py

Large diffs are not rendered by default.

532 changes: 532 additions & 0 deletions benchmarks/tabular/run.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
defaults:
- task/sentiment
- model/roberta
- method/sgmcmc_ll
- hyperparams/sghmc_ll

dataset:
base_data_path: ~
train_relative_path: ""
test_relative_path: ""
validation_relative_path: ""


model:
hparams:
tokenizer_max_length: 512
max_grad_norm: 1
adam_eps: 0.00000001
adam_b2: 0.999
gradient_checkpointing: "true"
save_every_n_steps: 20000
keep_top_n_checkpoints: 1
seed: 42
disable_jit: False
devices: -1

sagemaker:
account_id: ~
iam_role: ~
entrypoint: "benchmarks/transformers//prob_model_text_classification.py"
instance_type: "ml.g5.2xlarge"
profile: "default"
region: "us-east-1"
job_name_suffix: ~
metrics:
- {Name: "train_loss_step", Regex: 'loss: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "train_accuracy_step", Regex: 'accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "val_loss", Regex: 'val_loss: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "val_accuracy", Regex: 'val_accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "ind_accuracy", Regex: 'IND Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "ind_ece", Regex: 'IND ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "ood_accuracy", Regex: 'OOD Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'}
- {Name: "ood_ece", Regex: 'OOD ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'}

output_data_path: ~
3 changes: 1 addition & 2 deletions fortuna/conformal/classification/binary_multicalibrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ def calibration_error(
values=probs,
)

@staticmethod
def mean_squared_error(probs: Array, targets: Array) -> Array:
def mean_squared_error(self, probs: Array, targets: Array) -> Array:
return super().mean_squared_error(values=probs, scores=targets)

@staticmethod
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_b(
def _patch(
self, values: Array, patch: Array, bt: Array, ct: Array, _shift: bool = False
) -> Array:
if jnp.any(jnp.isnan(values)) or jnp.any(values.sum(1, keepdims=True) == 0.0):
if jnp.all(~jnp.isnan(values)) and jnp.all(values.sum(1, keepdims=True) != 0.0):
values /= values.sum(1, keepdims=True)
return super()._patch(values=values, patch=patch, bt=bt, ct=ct, _shift=_shift)

Expand Down
7 changes: 3 additions & 4 deletions fortuna/metric/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,9 @@ def compute_counts_confs_accs(
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
Number of inputs per bin, average confidence score per bin and average accuracy per bin.
"""
if probs.ndim != 2:
raise ValueError("""`probs` must be a two-dimensional array.""")
thresholds = jnp.linspace(1 / probs.shape[1], 1, 10)
probs = probs.max(1)
if probs.ndim == 2:
probs = probs.max(-1)
thresholds = jnp.linspace(1 / len(probs), 1, 10)
probs = jnp.array(probs)
indices = [jnp.where(probs <= thresholds[0])[0]]
indices += [
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "aws-fortuna"
version = "0.1.25"
version = "0.1.26"
description = "A Library for Uncertainty Quantification."
authors = ["Gianluca Detommaso <[email protected]>", "Alberto Gasparin <[email protected]>"]
license = "Apache-2.0"
Expand Down

0 comments on commit 36a0737

Please sign in to comment.