diff --git a/README.md b/README.md index a4536c35e..1be096c2e 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ ## Installation -You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker container](#install-in-docker) (recommended). +You can install this package and dependencies in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended). *TL;DR to install the Jax version for GPU run:* @@ -89,7 +89,8 @@ pip3 install -e '.[full]' ## Docker -We recommend using a Docker container to ensure a similar environment to our scoring and testing environments. +We recommend using a Docker container to ensure a similar environment to our scoring and testing environments. +Alternatively, a Singularity/Apptainer container can also be used (see instructions below). **Prerequisites for NVIDIA GPU set up**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs. @@ -133,6 +134,25 @@ To use the Docker container as an interactive virtual environment, you can run a ### Running Docker Container (End-to-end) To run a submission end-to-end in a containerized environment see [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container). +### Using Singularity/Apptainer instead of Docker +Since many compute clusters don't allow the usage of Docker due to security concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the Dockerfile provided here. 
+ +To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli): +```bash +pip3 install spython +cd algorithmic-efficiency/docker +spython recipe Dockerfile &> Singularity.def +``` +Now we can build the Apptainer image by running +```bash +singularity build --fakeroot <singularity_image_name>.sif Singularity.def +``` +To start a shell session with GPU support (by using the `--nv` flag), we can run +```bash +singularity shell --nv <singularity_image_name>.sif +``` +Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). + # Getting Started For instructions on developing and scoring your own algorithm in the benchmark see [Getting Started Document](./getting_started.md). ## Running a workload diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 45d77ede4..bc7eae3b8 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -144,7 +144,7 @@ def _build_input_queue( } padded_batch = data_utils.shard_and_maybe_pad_np( - numpy_batch, padding_value=1.0, global_batch_size=global_batch_size) + numpy_batch, padding_value=1.0) yield padded_batch # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4086a5841..cb07075f4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -50,3 +50,20 @@ 
def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' + + @property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1bb649ba8..09a4b0aa4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -59,3 +59,20 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] + + @property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py deleted file mode 100644 index f9fd30b0d..000000000 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ /dev/null @@ -1,21 +0,0 @@ -from algorithmic_efficiency.workloads.librispeech_conformer import workload - - -class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): - - @property - def validation_target_value(self) -> float: - return 0.118232 - - @property - def 
test_target_value(self) -> float: - return 0.073397 - - @property - def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 - - @property - def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours diff --git a/docker/Dockerfile b/docker/Dockerfile index ab6a798c1..bc3b51649 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,13 +6,12 @@ # To build Docker image FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG DEBIAN_FRONTEND=noninteractive # Installing machine packages RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz @@ -34,24 +33,24 @@ RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" ARG framework="both" ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git -RUN git clone $git_url && cd algorithmic-efficiency -RUN cd algorithmic-efficiency && git checkout $branch +RUN git clone $git_url && cd /algorithmic-efficiency +RUN cd /algorithmic-efficiency && git checkout $branch -RUN cd algorithmic-efficiency && pip install -e '.[full]' +RUN cd /algorithmic-efficiency && pip install -e '.[full]' RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_cpu]' \ && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ 
"$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ else \ @@ -59,13 +58,13 @@ RUN if [ "$framework" = "jax" ] ; then \ && exit 1 ; \ fi -RUN cd algorithmic-efficiency && pip install -e '.[wandb]' +RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' -RUN cd algorithmic-efficiency && git fetch origin -RUN cd algorithmic-efficiency && git pull +RUN cd /algorithmic-efficiency && git fetch origin +RUN cd /algorithmic-efficiency && git pull # Todo: remove this, this is temporary for developing COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh -ENTRYPOINT ["bash", "algorithmic-efficiency/docker/scripts/startup.sh"] +ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"] diff --git a/reference_algorithms/target_setting_algorithms/README.md b/reference_algorithms/target_setting_algorithms/README.md index 117bbed3c..822907ba3 100644 --- a/reference_algorithms/target_setting_algorithms/README.md +++ b/reference_algorithms/target_setting_algorithms/README.md @@ -113,7 +113,7 @@ python3 submission_runner.py \ --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ --workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` ```bash @@ -123,7 +123,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ 
--workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json index 13bf07b4b..482a28931 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.002106913873888147 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.8231189937738506 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.8774571227688758 ] }, "warmup_steps": { "feasible_points": [ - 9999 + 1199 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.27590534177690645 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json index 106e124a0..0a9bfb3cf 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json @@ -16,7 +16,7 @@ }, "warmup_steps": { "feasible_points": [ - 1200 + 720 ] }, "weight_decay": { diff --git a/scoring/scoring.py b/scoring/scoring.py index fff152255..12aae1357 100644 --- a/scoring/scoring.py +++ b/scoring/scoring.py @@ -53,7 +53,7 @@ def generate_eval_cols(metrics): - splits = ['train', 'validation', 'test'] + splits = ['train', 'validation'] return 
[f'{split}/{col}' for split, col in itertools.product(splits, metrics)] @@ -108,15 +108,13 @@ def get_index_that_reaches_best(workload_df, metric_col): def get_index_that_reaches_target(workload_df, validation_metric, - test_metric, - validation_target, - test_target): + validation_target): """Get the eval index in which a workload reaches the target metric_col. Args: workload_df: A subset of a submission's trials DataFrame that includes only the trials in a single workload. - metric_col: Name of array column in workload_df (e.g., `validation/l1_loss`). + metric_col: Name of array column in workload_df (e.g. `validation/l1_loss`). target: Target value for metric_col. Returns: @@ -125,20 +123,13 @@ def get_index_that_reaches_target(workload_df, """ is_minimized = check_if_minimized(validation_metric) validation_series = workload_df[validation_metric] - test_series = workload_df[test_metric] - validation_series = validation_series[validation_series != np.nan] - validation_series = validation_series[test_series != np.nan] - test_series = test_series[validation_series != np.nan] - test_series = test_series[test_series != np.nan] op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( lambda x: op(x, validation_target)) - test_target_reached = test_series.apply(lambda x: op(x, test_target)) - target_reached = pd.Series(validation_target_reached[0] - & test_target_reached[0]) + target_reached = pd.Series(validation_target_reached[0]) # Remove trials that never reach the target target_reached = target_reached[target_reached.apply(np.any)] @@ -188,12 +179,10 @@ def get_times_for_submission(submission, workload_init_kwargs=workload_init_kwargs) metric_name = workload_obj.target_metric_name validation_metric = f'validation/{metric_name}' - test_metric = f'test/{metric_name}' validation_target = workload_obj.validation_target_value - test_target = workload_obj.test_target_value trial_idx, time_idx = get_index_that_reaches_target( 
- group, validation_metric, test_metric, validation_target, test_target) + group, validation_metric, validation_target) if time_idx > -1: time_val = group[time_col].loc[trial_idx][time_idx] else: diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..47730d3fc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,6 +17,7 @@ import datetime import gc import importlib +import itertools import json import os import struct @@ -133,6 +134,14 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_integer( + 'hparam_start_index', + None, + 'Start index to slice set of hyperparameters in tuning search space.') +flags.DEFINE_integer( + 'hparam_end_index', + None, + 'End index to slice set of hyperparameters in tuning spearch space.') flags.DEFINE_integer( 'rng_seed', None, @@ -331,9 +340,12 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) + # Use 3x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < - workload.max_allowed_runtime_sec) + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval. 
if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): @@ -455,6 +467,8 @@ def score_submission_on_workload(workload: spec.Workload, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) @@ -500,7 +514,9 @@ def score_submission_on_workload(workload: spec.Workload, json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - for hi, hyperparameters in enumerate(tuning_search_space): + tuning_search_space_iter = itertools.islice( + enumerate(tuning_search_space), hparam_start_index, hparam_end_index) + for hi, hyperparameters in tuning_search_space_iter: # Generate a new seed from hardware sources of randomness for each trial. if not rng_seed: rng_seed = struct.unpack('I', os.urandom(4))[0] @@ -545,7 +561,7 @@ def score_submission_on_workload(workload: spec.Workload, all_timings.append(timing) all_metrics.append(metrics) score = min(all_timings) - for ti in range(num_tuning_trials): + for ti, _ in tuning_search_space_iter: logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') logging.info(f'Hyperparameters: {tuning_search_space[ti]}') logging.info(f'Metrics: {all_metrics[ti]}') @@ -554,6 +570,9 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: + if tuning_search_space is not None: + raise ValueError( + 'Cannot provide a tuning search space when using self tuning.') if not rng_seed: rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) @@ -621,6 +640,8 @@ def main(_): num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, save_checkpoints=FLAGS.save_checkpoints, + 
hparam_start_index=FLAGS.hparam_start_index, + hparam_end_index=FLAGS.hparam_end_index, rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}')