-
Notifications
You must be signed in to change notification settings - Fork 46
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add benchmarks for classification and regression tabular datasets (#118)
* edit installation instructions in readme * bump up version * make small change in readme because of publish to pypi error * bump up version * tabular benchmarks * refactor tabular run * add tabular benchmarks * pre-commit
- Loading branch information
1 parent
c51b646
commit 36a0737
Showing
8 changed files
with
2,288 additions
and
8 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
Large diffs are not rendered by default.
Oops, something went wrong.
45 changes: 45 additions & 0 deletions
45
...rks/transformers/sagemaker_entrypoints/prob_model_text_classification_config/default.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
defaults: | ||
- task/sentiment | ||
- model/roberta | ||
- method/sgmcmc_ll | ||
- hyperparams/sghmc_ll | ||
|
||
dataset: | ||
base_data_path: ~ | ||
train_relative_path: "" | ||
test_relative_path: "" | ||
validation_relative_path: "" | ||
|
||
|
||
model: | ||
hparams: | ||
tokenizer_max_length: 512 | ||
max_grad_norm: 1 | ||
adam_eps: 0.00000001 | ||
adam_b2: 0.999 | ||
gradient_checkpointing: "true" | ||
save_every_n_steps: 20000 | ||
keep_top_n_checkpoints: 1 | ||
seed: 42 | ||
disable_jit: False | ||
devices: -1 | ||
|
||
sagemaker: | ||
account_id: ~ | ||
iam_role: ~ | ||
entrypoint: "benchmarks/transformers//prob_model_text_classification.py" | ||
instance_type: "ml.g5.2xlarge" | ||
profile: "default" | ||
region: "us-east-1" | ||
job_name_suffix: ~ | ||
metrics: | ||
- {Name: "train_loss_step", Regex: 'loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "train_accuracy_step", Regex: 'accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "val_loss", Regex: 'val_loss: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "val_accuracy", Regex: 'val_accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "ind_accuracy", Regex: 'IND Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "ind_ece", Regex: 'IND ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "ood_accuracy", Regex: 'OOD Test accuracy: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
- {Name: "ood_ece", Regex: 'OOD ECE: ([-+]?(\d+(\.\d*)?|\.\d+))'} | ||
|
||
output_data_path: ~ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,6 @@ | ||
[tool.poetry] | ||
name = "aws-fortuna" | ||
version = "0.1.25" | ||
version = "0.1.26" | ||
description = "A Library for Uncertainty Quantification." | ||
authors = ["Gianluca Detommaso <[email protected]>", "Alberto Gasparin <[email protected]>"] | ||
license = "Apache-2.0" | ||
|