Skip to content

Commit

Permalink
Merge pull request #536 from mlcommons/dev
Browse files Browse the repository at this point in the history
dev -> main
  • Loading branch information
priyakasimbeg authored Oct 9, 2023
2 parents ddf5e14 + 4131232 commit e19dacf
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 64 deletions.
24 changes: 22 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


## Installation
You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker container](#install-in-docker) (recommended).
You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended).

*TL;DR to install the Jax version for GPU run:*

Expand Down Expand Up @@ -89,7 +89,8 @@ pip3 install -e '.[full]'
</details>

## Docker
We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
We recommend using a Docker container to ensure a similar environment to our scoring and testing environments.
Alternatively, a Singularity/Apptainer container can also be used (see instructions below).


**Prerequisites for NVIDIA GPU set up**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs.
Expand Down Expand Up @@ -133,6 +134,25 @@ To use the Docker container as an interactive virtual environment, you can run a
### Running Docker Container (End-to-end)
To run a submission end-to-end in a containerized environment see [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container).

### Using Singularity/Apptainer instead of Docker
Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the here provided Dockerfile.

To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli):
```bash
pip3 install spython
cd algorithmic-efficiency/docker
spython recipe Dockerfile &> Singularity.def
```
Now we can build the Apptainer image by running
```bash
singularity build --fakeroot <singularity_image_name>.sif Singularity.def
```
To start a shell session with GPU support (by using the `--nv` flag), we can run
```bash
singularity shell --nv <singularity_image_name>.sif
```
Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html).

# Getting Started
For instructions on developing and scoring your own algorithm in the benchmark see [Getting Started Document](./getting_started.md).
## Running a workload
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def _build_input_queue(
}

padded_batch = data_utils.shard_and_maybe_pad_np(
numpy_batch, padding_value=1.0, global_batch_size=global_batch_size)
numpy_batch, padding_value=1.0)
yield padded_batch

# Does NOT apply regularization, which is left to the submitter to do in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,20 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

@property
def validation_target_value(self) -> float:
return 0.118232

@property
def test_target_value(self) -> float:
return 0.073397

@property
def step_hint(self) -> int:
"""Max num steps the baseline algo was given to reach the target."""
return 48_000

@property
def max_allowed_runtime_sec(self) -> int:
return 55_506 # ~15.4 hours
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,20 @@ def init_model_fn(

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['lin.weight', 'lin.bias']

@property
def validation_target_value(self) -> float:
return 0.118232

@property
def test_target_value(self) -> float:
return 0.073397

@property
def step_hint(self) -> int:
"""Max num steps the baseline algo was given to reach the target."""
return 48_000

@property
def max_allowed_runtime_sec(self) -> int:
return 55_506 # ~15.4 hours

This file was deleted.

23 changes: 11 additions & 12 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@

# To build Docker image
FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04
ARG DEBIAN_FRONTEND=noninteractive

# Installing machine packages
RUN echo "Setting up machine"
RUN apt-get update
RUN apt-get install -y curl tar
RUN apt-get install -y git python3 pip wget ffmpeg
RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg
RUN apt-get install libtcmalloc-minimal4
RUN apt-get install unzip
RUN apt-get install pigz
Expand All @@ -34,38 +33,38 @@ RUN echo "Setting up algorithmic_efficiency repo"
ARG branch="main"
ARG framework="both"
ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git
RUN git clone $git_url && cd algorithmic-efficiency
RUN cd algorithmic-efficiency && git checkout $branch
RUN git clone $git_url && cd /algorithmic-efficiency
RUN cd /algorithmic-efficiency && git checkout $branch

RUN cd algorithmic-efficiency && pip install -e '.[full]'
RUN cd /algorithmic-efficiency && pip install -e '.[full]'

RUN if [ "$framework" = "jax" ] ; then \
echo "Installing Jax GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "pytorch" ] ; then \
echo "Installing Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_cpu]' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
elif [ "$framework" = "both" ] ; then \
echo "Installing Jax GPU and Pytorch GPU" \
&& cd algorithmic-efficiency \
&& cd /algorithmic-efficiency \
&& pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \
&& pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \
else \
echo "Invalid build-arg $framework: framework should be either jax, pytorch or both." >&2 \
&& exit 1 ; \
fi

RUN cd algorithmic-efficiency && pip install -e '.[wandb]'
RUN cd /algorithmic-efficiency && pip install -e '.[wandb]'

RUN cd algorithmic-efficiency && git fetch origin
RUN cd algorithmic-efficiency && git pull
RUN cd /algorithmic-efficiency && git fetch origin
RUN cd /algorithmic-efficiency && git pull

# Todo: remove this, this is temporary for developing
COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh
RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh

ENTRYPOINT ["bash", "algorithmic-efficiency/docker/scripts/startup.sh"]
ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"]
4 changes: 2 additions & 2 deletions reference_algorithms/target_setting_algorithms/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ python3 submission_runner.py \
--experiment_dir=$ROOT_DIR \
--experiment_name=target_setting \
--workload=librispeech_conformer \
--submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \
--submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \
--tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
```
```bash
Expand All @@ -123,7 +123,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc
--experiment_dir=$ROOT_DIR \
--experiment_name=target_setting \
--workload=librispeech_conformer \
--submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \
--submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \
--tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json
```

Expand Down
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
{
"learning_rate": {
"feasible_points": [
0.001308209823469072
0.002106913873888147
]
},
"beta1": {
"feasible_points": [
0.9731333693827139
0.8231189937738506
]
},
"beta2": {
"feasible_points": [
0.9981232922116359
0.8774571227688758
]
},
"warmup_steps": {
"feasible_points": [
9999
1199
]
},
"weight_decay": {
"feasible_points": [
0.16375311233774334
0.27590534177690645
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
},
"warmup_steps": {
"feasible_points": [
1200
720
]
},
"weight_decay": {
Expand Down
21 changes: 5 additions & 16 deletions scoring/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@


def generate_eval_cols(metrics):
splits = ['train', 'validation', 'test']
splits = ['train', 'validation']
return [f'{split}/{col}' for split, col in itertools.product(splits, metrics)]


Expand Down Expand Up @@ -108,15 +108,13 @@ def get_index_that_reaches_best(workload_df, metric_col):

def get_index_that_reaches_target(workload_df,
validation_metric,
test_metric,
validation_target,
test_target):
validation_target):
"""Get the eval index in which a workload reaches the target metric_col.
Args:
workload_df: A subset of a submission's trials DataFrame that
includes only the trials in a single workload.
metric_col: Name of array column in workload_df (e.g., `validation/l1_loss`).
metric_col: Name of array column in workload_df (e.g. `validation/l1_loss`).
target: Target value for metric_col.
Returns:
Expand All @@ -125,20 +123,13 @@ def get_index_that_reaches_target(workload_df,
"""
is_minimized = check_if_minimized(validation_metric)
validation_series = workload_df[validation_metric]
test_series = workload_df[test_metric]

validation_series = validation_series[validation_series != np.nan]
validation_series = validation_series[test_series != np.nan]
test_series = test_series[validation_series != np.nan]
test_series = test_series[test_series != np.nan]

op = operator.le if is_minimized else operator.ge
validation_target_reached = validation_series.apply(
lambda x: op(x, validation_target))
test_target_reached = test_series.apply(lambda x: op(x, test_target))

target_reached = pd.Series(validation_target_reached[0]
& test_target_reached[0])
target_reached = pd.Series(validation_target_reached[0])
# Remove trials that never reach the target
target_reached = target_reached[target_reached.apply(np.any)]

Expand Down Expand Up @@ -188,12 +179,10 @@ def get_times_for_submission(submission,
workload_init_kwargs=workload_init_kwargs)
metric_name = workload_obj.target_metric_name
validation_metric = f'validation/{metric_name}'
test_metric = f'test/{metric_name}'
validation_target = workload_obj.validation_target_value
test_target = workload_obj.test_target_value

trial_idx, time_idx = get_index_that_reaches_target(
group, validation_metric, test_metric, validation_target, test_target)
group, validation_metric, validation_target)
if time_idx > -1:
time_val = group[time_col].loc[trial_idx][time_idx]
else:
Expand Down
29 changes: 25 additions & 4 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import datetime
import gc
import importlib
import itertools
import json
import os
import struct
Expand Down Expand Up @@ -133,6 +134,14 @@
flags.DEFINE_boolean('save_checkpoints',
True,
'Whether or not to checkpoint the model at every eval.')
flags.DEFINE_integer(
'hparam_start_index',
None,
'Start index to slice set of hyperparameters in tuning search space.')
flags.DEFINE_integer(
'hparam_end_index',
None,
'End index to slice set of hyperparameters in tuning spearch space.')
flags.DEFINE_integer(
'rng_seed',
None,
Expand Down Expand Up @@ -331,9 +340,12 @@ def train_once(

train_state['accumulated_submission_time'] += (
train_step_end_time - train_state['last_step_end_time'])
# Use 3x the runtime budget for the self-tuning ruleset.
max_allowed_runtime_sec = (
workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external'
else 3 * workload.max_allowed_runtime_sec)
train_state['is_time_remaining'] = (
train_state['accumulated_submission_time'] <
workload.max_allowed_runtime_sec)
train_state['accumulated_submission_time'] < max_allowed_runtime_sec)
# Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):
Expand Down Expand Up @@ -455,6 +467,8 @@ def score_submission_on_workload(workload: spec.Workload,
num_tuning_trials: Optional[int] = None,
log_dir: Optional[str] = None,
save_checkpoints: Optional[bool] = True,
hparam_start_index: Optional[bool] = None,
hparam_end_index: Optional[bool] = None,
rng_seed: Optional[int] = None):
# Expand paths because '~' may not be recognized
data_dir = os.path.expanduser(data_dir)
Expand Down Expand Up @@ -500,7 +514,9 @@ def score_submission_on_workload(workload: spec.Workload,
json.load(search_space_file), num_tuning_trials)
all_timings = []
all_metrics = []
for hi, hyperparameters in enumerate(tuning_search_space):
tuning_search_space_iter = itertools.islice(
enumerate(tuning_search_space), hparam_start_index, hparam_end_index)
for hi, hyperparameters in tuning_search_space_iter:
# Generate a new seed from hardware sources of randomness for each trial.
if not rng_seed:
rng_seed = struct.unpack('I', os.urandom(4))[0]
Expand Down Expand Up @@ -545,7 +561,7 @@ def score_submission_on_workload(workload: spec.Workload,
all_timings.append(timing)
all_metrics.append(metrics)
score = min(all_timings)
for ti in range(num_tuning_trials):
for ti, _ in tuning_search_space_iter:
logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}')
logging.info(f'Hyperparameters: {tuning_search_space[ti]}')
logging.info(f'Metrics: {all_metrics[ti]}')
Expand All @@ -554,6 +570,9 @@ def score_submission_on_workload(workload: spec.Workload,
logging.info(f'Total number of evals: {num_evals}')
logging.info('=' * 20)
else:
if tuning_search_space is not None:
raise ValueError(
'Cannot provide a tuning search space when using self tuning.')
if not rng_seed:
rng_seed = struct.unpack('q', os.urandom(8))[0]
rng = prng.PRNGKey(rng_seed)
Expand Down Expand Up @@ -621,6 +640,8 @@ def main(_):
num_tuning_trials=FLAGS.num_tuning_trials,
log_dir=logging_dir_path,
save_checkpoints=FLAGS.save_checkpoints,
hparam_start_index=FLAGS.hparam_start_index,
hparam_end_index=FLAGS.hparam_end_index,
rng_seed=FLAGS.rng_seed)
logging.info(f'Final {FLAGS.workload} score: {score}')

Expand Down

0 comments on commit e19dacf

Please sign in to comment.