Merge pull request #618 from mlcommons/scoring_fixes
Scoring updates
priyakasimbeg authored Feb 1, 2024
2 parents 8108c00 + c0e1aad commit ff3c9b0
Showing 11 changed files with 532 additions and 73 deletions.
2 changes: 1 addition & 1 deletion DOCUMENTATION.md
@@ -409,7 +409,7 @@ The held-out workloads function similarly to a holdout test set discouraging sub

Modifications could, for example, include changing the number of layers or units (drawn from an interval), swapping the activation function (drawn from a set of applicable functions), or using different data augmentations (drawn from a list of possible pre-processing steps). The sample space should be wide enough to discourage submitters from simply trying them all out, but at the same time should be restricted enough to produce realistic workloads with acceptable achievable performances.

In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each fixed workload.
In the first iteration of this benchmark, we manually designed three different workload variants for each fixed workload. The variants are designed such that they achieve comparable performance to the fixed workload and that they might require different hyperparameters to achieve this performance. After the submission deadline, one held-out workload will be sampled for each dataset.

Our scoring procedure uses the held-out workloads only to penalize submissions that can't handle the introduced modifications (see the [Scoring](#scoring) section for further details).

42 changes: 40 additions & 2 deletions GETTING_STARTED.md
@@ -336,11 +336,49 @@ docker exec -it <container_id> /bin/bash
```

## Score your Submission
To score your submission, we will score over all fixed workloads, held-out workloads, and studies as described in the rules.
We will sample one held-out workload per dataset, for a total of 6 held-out workloads, and use each sampled
held-out workload in the scoring criteria for its matching base workload.
In other words, the total number of runs expected for official scoring (see the quick check below) is:
- for the external ruleset: (8 workloads + 6 held-out workloads) x 5 studies x 5 trials
- for the internal ruleset: (8 workloads + 6 held-out workloads) x 5 studies
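
As a quick sanity check of these totals, the shell arithmetic below (purely illustrative, not part of the repository) expands the two formulas:

```bash
# External ruleset: (8 fixed + 6 held-out) workloads x 5 studies x 5 trials
echo $(( (8 + 6) * 5 * 5 ))   # 350 runs
# Internal ruleset: (8 fixed + 6 held-out) workloads x 5 studies
echo $(( (8 + 6) * 5 ))       # 70 runs
```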

To produce the performance profiles and performance table, first run the workloads as described below and then score the resulting logs.


### Running workloads
To run workloads for scoring, you may specify a "virtual" list of held-out workloads. Note that
the official set of held-out workloads will be sampled by the competition organizers at scoring time.

An example config for held-out workloads is stored in `scoring/held_out_workloads_example.json`.
To generate a new sample of held-out workloads, run:

```bash
python3 generate_held_out_workloads.py --held_out_workloads_seed <optional_rng_seed> --output_filename <output_filename>
```

To run a number of studies and trials over all workloads, using a Docker container for each run:

```bash
python scoring/run_workloads.py \
--framework <framework> \
--experiment_name <experiment_name> \
--docker_image_url <docker_image_url> \
--submission_path <submission_path> \
--tuning_search_space <tuning_search_space> \
--held_out_workloads_config_path held_out_workloads_example.json \
--num_studies <num_studies> \
--seed <rng_seed>
```

Note that to run the above script you will need at least the `jax_cpu` and `pytorch_cpu` installations of the algorithmic-efficiency package.
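
For reference, a minimal install sketch follows; it assumes the package is installed from a local clone and that `jax_cpu` and `pytorch_cpu` are the extras names used by the package's setup (see `GETTING_STARTED.md` for the authoritative install commands):

```bash
# Hypothetical minimal CPU-only install for running the scoring scripts;
# the extras names are assumptions based on the note above.
git clone https://github.com/mlcommons/algorithmic-efficiency.git
cd algorithmic-efficiency
pip3 install -e '.[jax_cpu]'
pip3 install -e '.[pytorch_cpu]'
```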

During submission development, it might be useful to do faster, approximate scoring (e.g. without 5 different
studies or when some trials are missing), so the scoring scripts allow some flexibility. To simulate official scoring,
pass the `--strict=True` flag to `score_submissions.py`. To get the raw scores and performance profiles of a group of
submissions or a single submission:

```bash
python3 scoring/score_submission.py --experiment_path=<path_to_experiment_dir> --output_dir=<output_dir>
python score_submissions.py --submission_directory <directory_with_submissions> --output_dir <output_dir> --compute_performance_profiles
```
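
For example, to approximate the official scoring procedure described above, the strict flag can be combined with the performance-profile flag (a sketch; all paths are placeholders):

```bash
# Sketch: score a directory of submissions as in official scoring,
# failing on missing studies/trials (--strict=True) and computing
# performance profiles.
python score_submissions.py \
  --submission_directory <directory_with_submissions> \
  --output_dir <output_dir> \
  --compute_performance_profiles \
  --strict=True
```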

We provide the scores and performance profiles for the [paper baseline algorithms](/reference_algorithms/paper_baselines/) in the "Baseline Results" section in [Benchmarking Neural Network Training Algorithms](https://arxiv.org/abs/2306.07179).
40 changes: 40 additions & 0 deletions docker/scripts/startup.sh
@@ -34,6 +34,10 @@ function usage() {
from internal GCP bucket.
-i | --internal_contributor: If true, allow rsync of data and transfer of experiment results
with GCP project.
--num_tuning_trials Number of tuning trials for externally tuned ruleset submission.
--hparam_start_index Should be > 0 and < num_tuning_trials - 1.
--hparam_end_index Should be > 0 and < num_tuning_trials - 1.
--rng_seed RNG seed to pass to workload submission_runner.
--traindiffs_test: If true, ignore all other options and run the traindiffs test.
USAGE
exit 1
@@ -106,6 +110,22 @@ while [ "$1" != "" ]; do
shift
HOME_DIR=$1
;;
--num_tuning_trials)
shift
NUM_TUNING_TRIALS=$1
;;
--hparam_start_index)
shift
HPARAM_START_INDEX=$1
;;
--hparam_end_index)
shift
HPARAM_END_INDEX=$1
;;
--rng_seed)
shift
RNG_SEED=$1
;;
*)
usage
exit 1
@@ -193,6 +213,22 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
MAX_STEPS_FLAG="--max_global_steps=${MAX_GLOBAL_STEPS}"
fi

if [[ ! -z ${NUM_TUNING_TRIALS+x} ]]; then
NUM_TUNING_TRIALS_FLAG="--num_tuning_trials=${NUM_TUNING_TRIALS}"
fi

if [[ ! -z ${HPARAM_START_INDEX+x} ]]; then
HPARAM_START_INDEX_FLAG="--hparam_start_index=${HPARAM_START_INDEX}"
fi

if [[ ! -z ${HPARAM_END_INDEX+x} ]]; then
HPARAM_END_INDEX_FLAG="--hparam_end_index=${HPARAM_END_INDEX}"
fi

if [[ ! -z ${RNG_SEED+x} ]]; then
RNG_SEED_FLAG="--rng_seed=${RNG_SEED}"
fi

# Define special flags for imagenet and librispeech workloads
if [[ ${DATASET} == "imagenet" ]]; then
SPECIAL_FLAGS="--imagenet_v2_data_dir=${DATA_DIR}"
@@ -217,6 +253,10 @@ if [[ ! -z ${SUBMISSION_PATH+x} ]]; then
--experiment_name=${EXPERIMENT_NAME} \
--overwrite=${OVERWRITE} \
--save_checkpoints=${SAVE_CHECKPOINTS} \
${NUM_TUNING_TRIALS_FLAG} \
${HPARAM_START_INDEX_FLAG} \
${HPARAM_END_INDEX_FLAG} \
${RNG_SEED_FLAG} \
${MAX_STEPS_FLAG} \
${SPECIAL_FLAGS} \
${TORCH_COMPILE_FLAG} 2>&1 | tee -a ${LOG_FILE}"
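
To illustrate how these new options fit together, a hypothetical invocation of the startup script might look as follows; `<other_required_flags>` stands in for the existing arguments (framework, dataset, submission path, etc.) that are unchanged by this diff:

```bash
# Sketch only: run 5 tuning trials but only trials 1 through 3 on this
# machine, with a fixed RNG seed passed through to submission_runner.
bash docker/scripts/startup.sh <other_required_flags> \
  --num_tuning_trials 5 \
  --hparam_start_index 1 \
  --hparam_end_index 3 \
  --rng_seed 1996
```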
@@ -19,6 +19,7 @@
},
{
"dropout_rate": 0.0,
"label_smoothing": 0.0,
"learning_rate": 0.001308209823469072,
"one_minus_beta1": 0.02686663061,
"beta2": 0.9981232922116359,
@@ -27,6 +28,7 @@
},
{
"dropout_rate": 0.0,
"label_smoothing": 0.0,
"learning_rate": 0.004958460849689891,
"one_minus_beta1": 0.13625575743,
"beta2": 0.6291854735396584,
@@ -35,6 +37,7 @@
},
{
"dropout_rate": 0.1,
"label_smoothing": 0.0,
"learning_rate": 0.0017486387539278373,
"one_minus_beta1": 0.06733926164,
"beta2": 0.9955159689799007,
67 changes: 67 additions & 0 deletions scoring/generate_held_out_workloads.py
@@ -0,0 +1,67 @@
import json
import os
import struct

from absl import app
from absl import flags
from absl import logging
import numpy as np

flags.DEFINE_integer('held_out_workloads_seed',
None,
'Random seed for scoring.')
flags.DEFINE_string('output_filename',
'held_out_workloads.json',
'Path to file to record sampled held_out workloads.')
flags.DEFINE_string('framework', 'jax', 'JAX or PyTorch')
FLAGS = flags.FLAGS

HELD_OUT_WORKLOADS = {
'librispeech': [
'librispeech_conformer_attention_temperature',
'librispeech_conformer_layernorm',
'librispeech_conformer_gelu'
],
'imagenet': [
'imagenet_resnet_silu',
'imagenet_resnet_gelu',
'imagenet_resnet_large_bn_init',
'imagenet_vit_gelu',
'imagenet_vit_post_ln',
'imagenet_vit_map'
],
'ogbg': ['ogbg_gelu', 'ogbg_silu', 'ogbg_model_size'],
'wmt': ['wmt_post_ln', 'wmt_attention_temp', 'wmt_glu_tanh'],
'fastmri': ['fastmri_model_size', 'fastmri_tanh', 'fastmri_layernorm'],
'criteo1tb': [
'criteo1tb_layernorm', 'criteo1tb_embed_init', 'criteo1tb_resnet'
]
}


def save_held_out_workloads(held_out_workloads, filename):
with open(filename, "w") as f:
json.dump(held_out_workloads, f)


def main(_):
rng_seed = FLAGS.held_out_workloads_seed
output_filename = FLAGS.output_filename

if not rng_seed:
rng_seed = struct.unpack('I', os.urandom(4))[0]

logging.info('Using RNG seed %d', rng_seed)
rng = np.random.default_rng(rng_seed)

sampled_held_out_workloads = []
for k, v in HELD_OUT_WORKLOADS.items():
sampled_index = rng.integers(len(v))
sampled_held_out_workloads.append(v[sampled_index])

logging.info(f"Sampled held-out workloads: {sampled_held_out_workloads}")
save_held_out_workloads(sampled_held_out_workloads, output_filename)


if __name__ == '__main__':
app.run(main)
1 change: 1 addition & 0 deletions scoring/held_out_workloads_example.json
@@ -0,0 +1 @@
["librispeech_conformer_gelu", "imagenet_resnet_silu", "ogbg_gelu", "wmt_post_ln", "fastmri_model_size", "criteo1tb_layernorm"]