From 234eb5aabf53b68afbe932cdd7059b3baf37760f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:05:55 +0000 Subject: [PATCH 01/55] add hparam index flags to submission_runner --- submission_runner.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index f4ee32ede..0d7862ac5 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,6 +17,7 @@ import datetime import gc import importlib +import itertools import json import os import struct @@ -133,6 +134,8 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') +flags.DEFINE_int('hparam_start_index', None, 'Start index for hyperparameter selection.') +flags.DEFINE_int('hparam_end_index', None, 'End index for hyperparameter selection.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() @@ -449,7 +452,9 @@ def score_submission_on_workload(workload: spec.Workload, tuning_search_space: Optional[str] = None, num_tuning_trials: Optional[int] = None, log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True): + save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None,): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -494,7 +499,8 @@ def score_submission_on_workload(workload: spec.Workload, json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - for hi, hyperparameters in enumerate(tuning_search_space): + for hi, hyperparameters in itertools.islice(enumerate(tuning_search_space), + hparam_start_index, hparam_end_index): # Generate a new seed from hardware sources of randomness for each trial. 
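The `itertools.islice` over `enumerate(tuning_search_space)` introduced above restricts which tuning trials actually run while keeping each trial's original index `hi`. A minimal illustrative sketch of that behavior, using a made-up search space rather than the real tuning JSON:

```python
# Illustrative sketch (not part of the patch): slicing an enumerated
# search space keeps the original trial indices.
import itertools

tuning_search_space = ['hparams_0', 'hparams_1', 'hparams_2', 'hparams_3']

# Roughly what --hparam_start_index=1 --hparam_end_index=3 selects.
for hi, hyperparameters in itertools.islice(
    enumerate(tuning_search_space), 1, 3):
  print(hi, hyperparameters)  # 1 hparams_1, then 2 hparams_2

# Note: islice returns a one-shot iterator; once consumed it is exhausted
# and cannot be iterated over a second time.
```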
rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) @@ -610,7 +616,9 @@ def main(_): tuning_search_space=FLAGS.tuning_search_space, num_tuning_trials=FLAGS.num_tuning_trials, log_dir=logging_dir_path, - save_checkpoints=FLAGS.save_checkpoints) + save_checkpoints=FLAGS.save_checkpoints, + hparam_start_index=FLAGS.hparam_start_index, + hparam_end_index=FLAGS.hparam_end_index) logging.info(f'Final {FLAGS.workload} score: {score}') if FLAGS.profile: From 7a0f0f9d286158b78e520aa12236bf2f800b0116 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:19:37 +0000 Subject: [PATCH 02/55] fix --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 0d7862ac5..31b81b431 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -134,8 +134,8 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') -flags.DEFINE_int('hparam_start_index', None, 'Start index for hyperparameter selection.') -flags.DEFINE_int('hparam_end_index', None, 'End index for hyperparameter selection.') +flags.DEFINE_integer('hparam_start_index', None, 'Start index for hyperparameter selection.') +flags.DEFINE_integer('hparam_end_index', None, 'End index for hyperparameter selection.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() From fe05acd5be7807d7f2e523df1249ef161dee2c7d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:24:36 +0000 Subject: [PATCH 03/55] clarify --- submission_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 31b81b431..fe49d7410 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -134,8 +134,8 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') -flags.DEFINE_integer('hparam_start_index', None, 'Start index for hyperparameter selection.') -flags.DEFINE_integer('hparam_end_index', None, 'End index for hyperparameter selection.') +flags.DEFINE_integer('hparam_start_index', None, 'Start index to slice set of hyperparameters in tuning search space.') +flags.DEFINE_integer('hparam_end_index', None, 'End index to slice set of hyperparameters in tuning spearch space.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() From 77f44fb344d4791a31a5aede89ad0ff6e0ef9e8e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:26:16 +0000 Subject: [PATCH 04/55] reformatting --- submission_runner.py | 42 +++++++++++++++++++++++++----------------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index fe49d7410..924025d68 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -134,8 +134,14 @@ flags.DEFINE_boolean('save_checkpoints', True, 'Whether or not to checkpoint the model at every eval.') -flags.DEFINE_integer('hparam_start_index', None, 'Start index to slice set of hyperparameters in tuning search space.') -flags.DEFINE_integer('hparam_end_index', None, 'End index to slice set of hyperparameters in tuning spearch space.') +flags.DEFINE_integer( + 'hparam_start_index', + None, + 'Start index to slice set of hyperparameters in tuning search space.') +flags.DEFINE_integer( + 'hparam_end_index', + None, + 'End index to slice set of hyperparameters in tuning spearch space.') FLAGS = flags.FLAGS USE_PYTORCH_DDP, RANK, DEVICE, 
N_GPUS = pytorch_setup() @@ -441,20 +447,22 @@ def train_once( return train_state['accumulated_submission_time'], metrics -def score_submission_on_workload(workload: spec.Workload, - workload_name: str, - submission_path: str, - data_dir: str, - tuning_ruleset: str, - profiler: Optional[Profiler] = None, - max_global_steps: Optional[int] = None, - imagenet_v2_data_dir: Optional[str] = None, - tuning_search_space: Optional[str] = None, - num_tuning_trials: Optional[int] = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True, - hparam_start_index: Optional[bool] = None, - hparam_end_index: Optional[bool] = None,): +def score_submission_on_workload( + workload: spec.Workload, + workload_name: str, + submission_path: str, + data_dir: str, + tuning_ruleset: str, + profiler: Optional[Profiler] = None, + max_global_steps: Optional[int] = None, + imagenet_v2_data_dir: Optional[str] = None, + tuning_search_space: Optional[str] = None, + num_tuning_trials: Optional[int] = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, +): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -499,7 +507,7 @@ def score_submission_on_workload(workload: spec.Workload, json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - for hi, hyperparameters in itertools.islice(enumerate(tuning_search_space), + for hi, hyperparameters in itertools.islice(enumerate(tuning_search_space), hparam_start_index, hparam_end_index): # Generate a new seed from hardware sources of randomness for each trial. rng_seed = struct.unpack('I', os.urandom(4))[0] From 0100e442e9f2fdc466833fe2571400672724816f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:30:10 +0000 Subject: [PATCH 05/55] fix --- submission_runner.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 924025d68..580c87bd0 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -507,8 +507,9 @@ def score_submission_on_workload( json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - for hi, hyperparameters in itertools.islice(enumerate(tuning_search_space), - hparam_start_index, hparam_end_index): + tuning_search_space = itertools.islice( + enumerate(tuning_search_space), hparam_start_index, hparam_end_index) + for hi, hyperparameters in tuning_search_space: # Generate a new seed from hardware sources of randomness for each trial. 
rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) @@ -550,7 +551,7 @@ def score_submission_on_workload( all_timings.append(timing) all_metrics.append(metrics) score = min(all_timings) - for ti in range(num_tuning_trials): + for ti, _ in tuning_search_space: logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') logging.info(f'Hyperparameters: {tuning_search_space[ti]}') logging.info(f'Metrics: {all_metrics[ti]}') From 3c5882b278d3f2dc05d545507ea623e5f52c73c7 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 13 Sep 2023 00:32:57 +0000 Subject: [PATCH 06/55] fix --- submission_runner.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 580c87bd0..4fa596b60 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -507,9 +507,9 @@ def score_submission_on_workload( json.load(search_space_file), num_tuning_trials) all_timings = [] all_metrics = [] - tuning_search_space = itertools.islice( + tuning_search_space_iter = itertools.islice( enumerate(tuning_search_space), hparam_start_index, hparam_end_index) - for hi, hyperparameters in tuning_search_space: + for hi, hyperparameters in tuning_search_space_iter: # Generate a new seed from hardware sources of randomness for each trial. rng_seed = struct.unpack('I', os.urandom(4))[0] logging.info('Using RNG seed %d', rng_seed) @@ -551,7 +551,7 @@ def score_submission_on_workload( all_timings.append(timing) all_metrics.append(metrics) score = min(all_timings) - for ti, _ in tuning_search_space: + for ti, _ in tuning_search_space_iter: logging.info(f'Tuning trial {ti + 1}/{num_tuning_trials}') logging.info(f'Hyperparameters: {tuning_search_space[ti]}') logging.info(f'Metrics: {all_metrics[ti]}') From 301de4ab3f93eeafe3149652be6f36a554163b88 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 18 Sep 2023 16:15:54 +0000 Subject: [PATCH 07/55] Switch to absolute paths in Dockerfile --- docker/Dockerfile | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index ab6a798c1..dceee80ca 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -34,24 +34,24 @@ RUN echo "Setting up algorithmic_efficiency repo" ARG branch="main" ARG framework="both" ARG git_url=https://github.com/mlcommons/algorithmic-efficiency.git -RUN git clone $git_url && cd algorithmic-efficiency -RUN cd algorithmic-efficiency && git checkout $branch +RUN git clone $git_url && cd /algorithmic-efficiency +RUN cd /algorithmic-efficiency && git checkout $branch -RUN cd algorithmic-efficiency && pip install -e '.[full]' +RUN cd /algorithmic-efficiency && pip install -e '.[full]' RUN if [ "$framework" = "jax" ] ; then \ echo "Installing Jax GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ && pip install -e '.[pytorch_cpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ "$framework" = "pytorch" ] ; then \ echo "Installing Pytorch GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_cpu]' \ && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ elif [ "$framework" = "both" ] ; then \ echo "Installing Jax GPU and Pytorch GPU" \ - && cd algorithmic-efficiency \ + && cd /algorithmic-efficiency \ && pip install -e '.[jax_gpu]' -f 
'https://storage.googleapis.com/jax-releases/jax_cuda_releases.html' \ && pip install -e '.[pytorch_gpu]' -f 'https://download.pytorch.org/whl/torch_stable.html'; \ else \ @@ -59,13 +59,13 @@ RUN if [ "$framework" = "jax" ] ; then \ && exit 1 ; \ fi -RUN cd algorithmic-efficiency && pip install -e '.[wandb]' +RUN cd /algorithmic-efficiency && pip install -e '.[wandb]' -RUN cd algorithmic-efficiency && git fetch origin -RUN cd algorithmic-efficiency && git pull +RUN cd /algorithmic-efficiency && git fetch origin +RUN cd /algorithmic-efficiency && git pull # Todo: remove this, this is temporary for developing COPY scripts/startup.sh /algorithmic-efficiency/docker/scripts/startup.sh RUN chmod a+x /algorithmic-efficiency/docker/scripts/startup.sh -ENTRYPOINT ["bash", "algorithmic-efficiency/docker/scripts/startup.sh"] +ENTRYPOINT ["bash", "/algorithmic-efficiency/docker/scripts/startup.sh"] From cc6d0dbdf86c1f3a0c6273f14ec26608f21ad025 Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 18 Sep 2023 16:23:59 +0000 Subject: [PATCH 08/55] Only set DEBIAN_FRONTEND where necessary --- docker/Dockerfile | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index dceee80ca..bc3b51649 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -6,13 +6,12 @@ # To build Docker image FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 -ARG DEBIAN_FRONTEND=noninteractive # Installing machine packages RUN echo "Setting up machine" RUN apt-get update RUN apt-get install -y curl tar -RUN apt-get install -y git python3 pip wget ffmpeg +RUN DEBIAN_FRONTEND=noninteractive apt-get install -y git python3 pip wget ffmpeg RUN apt-get install libtcmalloc-minimal4 RUN apt-get install unzip RUN apt-get install pigz From e7b854c3116cfa58fd81f7f33d8891cd7bd9cdfa Mon Sep 17 00:00:00 2001 From: runame Date: Mon, 18 Sep 2023 19:44:20 +0200 Subject: [PATCH 09/55] Add instructions for running Singularity/Apptainer container to README --- README.md | 24 ++++++++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index a4536c35e..1be096c2e 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ ## Installation -You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker container](#install-in-docker) (recommended). +You can install this package and dependences in a [python virtual environment](#virtual-environment) or use a [Docker/Singularity/Apptainer container](#install-in-docker) (recommended). *TL;DR to install the Jax version for GPU run:* @@ -89,7 +89,8 @@ pip3 install -e '.[full]' ## Docker -We recommend using a Docker container to ensure a similar environment to our scoring and testing environments. +We recommend using a Docker container to ensure a similar environment to our scoring and testing environments. +Alternatively, a Singularity/Apptainer container can also be used (see instructions below). **Prerequisites for NVIDIA GPU set up**: You may have to install the NVIDIA Container Toolkit so that the containers can locate the NVIDIA drivers and GPUs. @@ -133,6 +134,25 @@ To use the Docker container as an interactive virtual environment, you can run a ### Running Docker Container (End-to-end) To run a submission end-to-end in a containerized environment see [Getting Started Document](./getting_started.md#run-your-submission-in-a-docker-container). 
+### Using Singularity/Apptainer instead of Docker +Since many compute clusters don't allow the usage of Docker due to securtiy concerns and instead encourage the use of [Singularity/Apptainer](https://github.com/apptainer/apptainer) (formerly Singularity, now called Apptainer), we also provide instructions on how to build an Apptainer container based on the here provided Dockerfile. + +To convert the Dockerfile into an Apptainer definition file, we will use [spython](https://github.com/singularityhub/singularity-cli): +```bash +pip3 install spython +cd algorithmic-efficiency/docker +spython recipe Dockerfile &> Singularity.def +``` +Now we can build the Apptainer image by running +```bash +singularity build --fakeroot .sif Singularity.def +``` +To start a shell session with GPU support (by using the `--nv` flag), we can run +```bash +singularity shell --nv .sif +``` +Similarly to Docker, Apptainer allows you to bind specific paths on the host system and the container by specifying the `--bind` flag, as explained [here](https://docs.sylabs.io/guides/3.7/user-guide/bind_paths_and_mounts.html). + # Getting Started For instructions on developing and scoring your own algorithm in the benchmark see [Getting Started Document](./getting_started.md). ## Running a workload From 7970388e739aca7612672ed3f931ca8e007605b5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 25 Sep 2023 05:31:19 +0000 Subject: [PATCH 10/55] update tuning search spaces for speech_workloads --- .../librispeech_conformer/tuning_search_space.json | 10 +++++----- .../librispeech_deepspeech/tuning_search_space.json | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json index 13bf07b4b..482a28931 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json @@ -1,27 +1,27 @@ { "learning_rate": { "feasible_points": [ - 0.001308209823469072 + 0.002106913873888147 ] }, "beta1": { "feasible_points": [ - 0.9731333693827139 + 0.8231189937738506 ] }, "beta2": { "feasible_points": [ - 0.9981232922116359 + 0.8774571227688758 ] }, "warmup_steps": { "feasible_points": [ - 9999 + 1199 ] }, "weight_decay": { "feasible_points": [ - 0.16375311233774334 + 0.27590534177690645 ] } } diff --git a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json index 106e124a0..0a9bfb3cf 100644 --- a/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json +++ b/reference_algorithms/target_setting_algorithms/librispeech_deepspeech/tuning_search_space.json @@ -16,7 +16,7 @@ }, "warmup_steps": { "feasible_points": [ - 1200 + 720 ] }, "weight_decay": { From 862f5009b571d64ec3a808e05689be7b079cbbc3 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 26 Sep 2023 22:55:51 +0000 Subject: [PATCH 11/55] add tabulate for deepspeech debugging --- .../librispeech_deepspeech/librispeech_jax/workload.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 
4086a5841..54da37f0e 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -5,6 +5,8 @@ import jax import jax.numpy as jnp import numpy as np +import flax.linen as nn +from absl import logging from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -34,6 +36,11 @@ def init_model_fn( input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] + tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), + console_kwargs={'force_terminal': False, + 'force_jupyter': False, + 'width': 240}) + logging.info(tabuleate_fn(*fake_input_batch), train=False) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2) From 181623494a1b3c617cf1d7f77203b392e4561959 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 26 Sep 2023 23:00:43 +0000 Subject: [PATCH 12/55] update target setting algo for conformer to adamw --- reference_algorithms/target_setting_algorithms/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/reference_algorithms/target_setting_algorithms/README.md b/reference_algorithms/target_setting_algorithms/README.md index 117bbed3c..822907ba3 100644 --- a/reference_algorithms/target_setting_algorithms/README.md +++ b/reference_algorithms/target_setting_algorithms/README.md @@ -113,7 +113,7 @@ python3 submission_runner.py \ --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ --workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/jax_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/jax_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` ```bash @@ -123,7 +123,7 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc --experiment_dir=$ROOT_DIR \ --experiment_name=target_setting \ --workload=librispeech_conformer \ - --submission_path=reference_algorithms/target_setting_algorithms/pytorch_nadamw.py \ + --submission_path=reference_algorithms/target_setting_algorithms/pytorch_adamw.py \ --tuning_search_space=reference_algorithms/target_setting_algorithms/librispeech_conformer/tuning_search_space.json ``` From 6b9655f9edc96d421b971f15febe0eaeba7a57a2 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Tue, 26 Sep 2023 23:07:13 +0000 Subject: [PATCH 13/55] tabulate typo fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 54da37f0e..46c524381 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -40,7 +40,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240}) - logging.info(tabuleate_fn(*fake_input_batch), train=False) + logging.info(tabulate_fn(*fake_input_batch), train=False) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2) From 28db3923487300d6d20fc1ef7893442710ccd6a5 Mon Sep 17 00:00:00 
2001 From: priyakasimbeg Date: Tue, 26 Sep 2023 23:08:46 +0000 Subject: [PATCH 14/55] typo --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 46c524381..1d2c48dbc 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -40,7 +40,7 @@ def init_model_fn( console_kwargs={'force_terminal': False, 'force_jupyter': False, 'width': 240}) - logging.info(tabulate_fn(*fake_input_batch), train=False) + logging.info(tabulate_fn(*fake_input_batch, train=False)) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2) From fab70f9c2e878566c7fe2f2fd03de525ea14719e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Wed, 27 Sep 2023 21:00:25 +0000 Subject: [PATCH 15/55] add lr logging target_setting nadamw --- algorithmic_efficiency/checkpoint_utils.py | 2 +- reference_algorithms/target_setting_algorithms/jax_nadamw.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 5 +++-- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index fb7449b99..5914654aa 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -62,7 +62,7 @@ def maybe_restore_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). """ if framework == 'jax': - opt_state, opt_update_fn = optimizer_state + opt_state, opt_update_fn, lr_schedule_fn = optimizer_state else: opt_state, opt_update_fn = optimizer_state, None diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 21f2a7b2b..ec339b8b7 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -168,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn + return jax_utils.replicate(optimizer_state), opt_update_fn, lr_schedule_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..172b9edda 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -85,7 +85,7 @@ def update_params(workload: spec.Workload, del loss_type del eval_results - optimizer_state, opt_update_fn = optimizer_state + optimizer_state, opt_update_fn, lr_schedule_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing @@ -107,5 +107,6 @@ def update_params(workload: spec.Workload, { 'loss': loss[0], 'grad_norm': grad_norm[0], + 'learning_rate': lr_schedule_fn(global_step) }, global_step) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state + return (new_optimizer_state, opt_update_fn, 
lr_schedule_fn), new_params, new_model_state From 1ea08b0e2ad760a4483f1583faa4749c8dd56b35 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 01:54:00 +0000 Subject: [PATCH 16/55] reverse padding fixes --- algorithmic_efficiency/data_utils.py | 31 ++++++++++++++-------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 14e3c7c6c..38744716b 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -28,15 +28,8 @@ def shard_and_maybe_pad_np( inputs = batch['inputs'] current_batch_size = inputs[0].shape[0] if isinstance( inputs, tuple) else inputs.shape[0] - if global_batch_size is not None: - assert global_batch_size >= current_batch_size, \ - 'global_batch_size must be larger than or equal to current_batch_size.' - # Always pad to global_batch_size if it is provided. - pad_to_global_batch_size = global_batch_size > current_batch_size - else: - pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0 or pad_to_global_batch_size: + if remainder_size != 0: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -57,8 +50,8 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0 or pad_to_global_batch_size: - x = pad(x, pad_size, padding_value=padding_value) + if remainder_size != 0: + x = pad(x, pad_size, 'jax', padding_value=padding_value) # Reshape (global_batch_size, ...) to # (local_device_count, per_device_batch_size, ...). @@ -68,13 +61,21 @@ def _prepare(x): return jax.tree_map(_prepare, batch) -def pad(tensor: np.ndarray, +def pad(tensor: spec.Tensor, pad_size: int, - padding_value: int = 0) -> np.ndarray: - if tensor.ndim > 1: + framework: str, + padding_value: int = 0) -> spec.Tensor: + if len(tensor) > 1: pad_size = (pad_size, *tensor.shape[1:]) - padding = np.full(pad_size, padding_value, dtype=tensor.dtype) - padded_tensor = np.concatenate((tensor, padding), axis=0) + if framework == 'pytorch': + padding = torch.full( + pad_size, padding_value, dtype=tensor.dtype, device=tensor.device) + padded_tensor = torch.cat((tensor, padding), dim=0) + elif framework == 'jax': + padding = np.full(pad_size, padding_value, dtype=tensor.dtype) + padded_tensor = np.concatenate((tensor, padding), axis=0) + else: + raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.') return padded_tensor From c3cf664f7cd27c0bd3315481e4700236f627faed Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:15:42 +0000 Subject: [PATCH 17/55] log_step_hint --- reference_algorithms/target_setting_algorithms/jax_nadamw.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index ec339b8b7..0b64741f2 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp import optax +from absl import logging from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -152,6 +153,7 @@ def init_optimizer_state(workload: spec.Workload, del rng target_setting_step_hint = int(0.75 * workload.step_hint) + 
logging.info(f'target setting step hint: {target_setting_step_hint}') lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, hyperparameters) @@ -168,4 +170,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn, lr_schedule_fn + return jax_utils.replicate(optimizer_state), opt_update_fn, opt_update_fn From 784230f8a48596ab1241e4f052f83b653236cc4e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:19:15 +0000 Subject: [PATCH 18/55] log step hint in submission runner --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index 2289d39d3..3d852518b 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -530,6 +530,7 @@ def score_submission_on_workload(workload: spec.Workload, with profiler.profile('Train'): if 'imagenet' not in workload_name: imagenet_v2_data_dir = None + logging.info(f"Workload step hint: {workload.step_hint}") timing, metrics = train_once(workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, From e83937b8fec7f8c8f635caec6acaa1928bdfa506 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:21:23 +0000 Subject: [PATCH 19/55] add logging --- submission_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/submission_runner.py b/submission_runner.py index 3d852518b..1350afbfe 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -531,6 +531,7 @@ def score_submission_on_workload(workload: spec.Workload, if 'imagenet' not in workload_name: imagenet_v2_data_dir = None logging.info(f"Workload step hint: {workload.step_hint}") + logging.info(f'Workload setting step hint: {target_setting_step_hint}') timing, metrics = train_once(workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, From 0606b0174e557b10e213f68e4d1b5a900a90adf1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:30:09 +0000 Subject: [PATCH 20/55] lgoging --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 1350afbfe..3a0f699a1 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -531,7 +531,6 @@ def score_submission_on_workload(workload: spec.Workload, if 'imagenet' not in workload_name: imagenet_v2_data_dir = None logging.info(f"Workload step hint: {workload.step_hint}") - logging.info(f'Workload setting step hint: {target_setting_step_hint}') timing, metrics = train_once(workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, @@ -599,6 +598,7 @@ def main(_): workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], workload_init_kwargs=workload_init_kwargs) + logging.info(f'workload : {workload_path} {workload_class_name}') experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: From b23aae4a41a1ecae6b8d9a9acaf91b696ec284d5 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:33:46 +0000 Subject: [PATCH 21/55] logging --- submission_runner.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 3a0f699a1..6f741680e 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -598,7 +598,10 @@ def main(_): 
workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], workload_init_kwargs=workload_init_kwargs) - logging.info(f'workload : {workload_path} {workload_class_name}') + workload_path = workload_metadata['workload_path'] + workload_class = workload_metadata['workload_class'] + logging.info(f'workload : {workload_path}') + logging.info(f'workload class: {workload_class}') experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: From 27cb0379cf235a2ef78308d9fc41a109f5a7b647 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:35:43 +0000 Subject: [PATCH 22/55] fix logging --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index 6f741680e..ce726b59c 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -599,7 +599,7 @@ def main(_): workload_class_name=workload_metadata['workload_class_name'], workload_init_kwargs=workload_init_kwargs) workload_path = workload_metadata['workload_path'] - workload_class = workload_metadata['workload_class'] + workload_class = workload_metadata['workload_class_name'] logging.info(f'workload : {workload_path}') logging.info(f'workload class: {workload_class}') From a0c86248eb3770ac521376eeb95ef4be450860f1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:45:41 +0000 Subject: [PATCH 23/55] more logging; --- submission_runner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/submission_runner.py b/submission_runner.py index ce726b59c..84e29eafc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -602,7 +602,7 @@ def main(_): workload_class = workload_metadata['workload_class_name'] logging.info(f'workload : {workload_path}') logging.info(f'workload class: {workload_class}') - + logging.info(f'workload max run time: {workload.max_allowed_runtime_sec}') experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S') From d977605f8a9524fca6a0f666797e4eaca3ab975e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:50:21 +0000 Subject: [PATCH 24/55] remove inheritance --- .../librispeech_deepspeech/workload.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index f9fd30b0d..6c157a60e 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,7 +1,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer import workload -class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): +class BaseDeepspeechLibrispeechWorkload(spec.Workload): @property def validation_target_value(self) -> float: @@ -11,6 +11,50 @@ def validation_target_value(self) -> float: def test_target_value(self) -> float: return 0.073397 + def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: + return eval_result['test/wer'] < self.test_target_value + + @property + def loss_type(self) -> spec.LossType: + return spec.LossType.CTC_LOSS + + @property + def num_train_examples(self) -> int: + return 263840 + + @property + def num_eval_train_examples(self) -> int: + # Round up from num_validation_examples (which is the default for + # 
num_eval_train_examples) to the next multiple of eval_batch_size, so that + # we don't have to extract the correctly sized subset of the training data. + rounded_up_multiple = math.ceil(self.num_validation_examples / + self.eval_batch_size) + return rounded_up_multiple * self.eval_batch_size + + @property + def num_validation_examples(self) -> int: + return 5348 + + @property + def num_test_examples(self) -> int: + return 2472 + + @property + def eval_batch_size(self) -> int: + return 256 + + @property + def train_mean(self): + raise NotImplementedError + + @property + def train_stddev(self): + raise NotImplementedError + + @property + def eval_period_time_sec(self) -> int: + return 24 * 60 + @property def step_hint(self) -> int: """Max num steps the baseline algo was given to reach the target.""" From 9749f00acda23187fe398b29080b8d4dfb4b6c7c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 28 Sep 2023 21:55:33 +0000 Subject: [PATCH 25/55] temp change to conformer --- .../workloads/librispeech_conformer/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index dc7fb912b..c8a6d0066 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -76,4 +76,4 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Max num steps the baseline algo was given to reach the target.""" - return 80_000 + return 60_000 From 869bdadb44e184010fe47a3d3699eeac9eb30c26 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 05:15:19 +0000 Subject: [PATCH 26/55] fix step hint --- .../workloads/librispeech_conformer/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index c8a6d0066..dc7fb912b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -76,4 +76,4 @@ def eval_period_time_sec(self) -> int: @property def step_hint(self) -> int: """Max num steps the baseline algo was given to reach the target.""" - return 60_000 + return 80_000 From 0a60f89a56b06fdd3307e474e049a21f20e601f4 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 20:31:55 +0000 Subject: [PATCH 27/55] deepspeech inheritance fix --- .../librispeech_jax/workload.py | 2 +- .../librispeech_pytorch/workload.py | 2 +- .../librispeech_deepspeech/workload.py | 46 +------------------ 3 files changed, 3 insertions(+), 47 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 1d2c48dbc..884c8f00f 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -16,7 +16,7 @@ models -class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): +class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): def init_model_fn( self, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py 
b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1bb649ba8..b2a1d8ea1 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -20,7 +20,7 @@ MAX_INPUT_LENGTH = 320000 -class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): +class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): def init_model_fn( self, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index 6c157a60e..63253ff34 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -9,51 +9,7 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: - return 0.073397 - - def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool: - return eval_result['test/wer'] < self.test_target_value - - @property - def loss_type(self) -> spec.LossType: - return spec.LossType.CTC_LOSS - - @property - def num_train_examples(self) -> int: - return 263840 - - @property - def num_eval_train_examples(self) -> int: - # Round up from num_validation_examples (which is the default for - # num_eval_train_examples) to the next multiple of eval_batch_size, so that - # we don't have to extract the correctly sized subset of the training data. - rounded_up_multiple = math.ceil(self.num_validation_examples / - self.eval_batch_size) - return rounded_up_multiple * self.eval_batch_size - - @property - def num_validation_examples(self) -> int: - return 5348 - - @property - def num_test_examples(self) -> int: - return 2472 - - @property - def eval_batch_size(self) -> int: - return 256 - - @property - def train_mean(self): - raise NotImplementedError - - @property - def train_stddev(self): - raise NotImplementedError - - @property - def eval_period_time_sec(self) -> int: - return 24 * 60 + return 0.07339 @property def step_hint(self) -> int: From 06df206afa39518f6e426c39614bc492a8784c7f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 20:38:40 +0000 Subject: [PATCH 28/55] remove lr schedule logging --- algorithmic_efficiency/checkpoint_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/checkpoint_utils.py b/algorithmic_efficiency/checkpoint_utils.py index 5914654aa..fb7449b99 100644 --- a/algorithmic_efficiency/checkpoint_utils.py +++ b/algorithmic_efficiency/checkpoint_utils.py @@ -62,7 +62,7 @@ def maybe_restore_checkpoint(framework: str, train_state, eval_results, global_step, preemption_count). 
""" if framework == 'jax': - opt_state, opt_update_fn, lr_schedule_fn = optimizer_state + opt_state, opt_update_fn = optimizer_state else: opt_state, opt_update_fn = optimizer_state, None From aa6c538c84466660420c573390fc613a00994163 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 20:45:42 +0000 Subject: [PATCH 29/55] debugging statements --- algorithmic_efficiency/data_utils.py | 31 +++++++++---------- .../librispeech_jax/workload.py | 7 ----- .../librispeech_deepspeech/workload.py | 2 +- .../target_setting_algorithms/jax_nadamw.py | 4 +-- .../jax_submission_base.py | 5 ++- submission_runner.py | 7 +---- 6 files changed, 20 insertions(+), 36 deletions(-) diff --git a/algorithmic_efficiency/data_utils.py b/algorithmic_efficiency/data_utils.py index 38744716b..14e3c7c6c 100644 --- a/algorithmic_efficiency/data_utils.py +++ b/algorithmic_efficiency/data_utils.py @@ -28,8 +28,15 @@ def shard_and_maybe_pad_np( inputs = batch['inputs'] current_batch_size = inputs[0].shape[0] if isinstance( inputs, tuple) else inputs.shape[0] + if global_batch_size is not None: + assert global_batch_size >= current_batch_size, \ + 'global_batch_size must be larger than or equal to current_batch_size.' + # Always pad to global_batch_size if it is provided. + pad_to_global_batch_size = global_batch_size > current_batch_size + else: + pad_to_global_batch_size = False remainder_size = current_batch_size % local_device_count - if remainder_size != 0: + if remainder_size != 0 or pad_to_global_batch_size: if global_batch_size is not None: pad_size = global_batch_size - current_batch_size else: @@ -50,8 +57,8 @@ def _prepare(x): x = x._numpy() # pylint: disable=protected-access # Pad if remainder_size != 0 (should only be possible during evaluation). - if remainder_size != 0: - x = pad(x, pad_size, 'jax', padding_value=padding_value) + if remainder_size != 0 or pad_to_global_batch_size: + x = pad(x, pad_size, padding_value=padding_value) # Reshape (global_batch_size, ...) to # (local_device_count, per_device_batch_size, ...). 
@@ -61,21 +68,13 @@ def _prepare(x): return jax.tree_map(_prepare, batch) -def pad(tensor: spec.Tensor, +def pad(tensor: np.ndarray, pad_size: int, - framework: str, - padding_value: int = 0) -> spec.Tensor: - if len(tensor) > 1: + padding_value: int = 0) -> np.ndarray: + if tensor.ndim > 1: pad_size = (pad_size, *tensor.shape[1:]) - if framework == 'pytorch': - padding = torch.full( - pad_size, padding_value, dtype=tensor.dtype, device=tensor.device) - padded_tensor = torch.cat((tensor, padding), dim=0) - elif framework == 'jax': - padding = np.full(pad_size, padding_value, dtype=tensor.dtype) - padded_tensor = np.concatenate((tensor, padding), axis=0) - else: - raise ValueError(f'Framework has to be pytorch or jax, but is {framework}.') + padding = np.full(pad_size, padding_value, dtype=tensor.dtype) + padded_tensor = np.concatenate((tensor, padding), axis=0) return padded_tensor diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 884c8f00f..74dfe2f97 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -5,8 +5,6 @@ import jax import jax.numpy as jnp import numpy as np -import flax.linen as nn -from absl import logging from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec @@ -36,11 +34,6 @@ def init_model_fn( input_shape = [(320000,), (320000,)] fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape] - tabulate_fn = nn.tabulate(self._model, jax.random.PRNGKey(0), - console_kwargs={'force_terminal': False, - 'force_jupyter': False, - 'width': 240}) - logging.info(tabulate_fn(*fake_input_batch, train=False)) model_init_fn = jax.jit(functools.partial(self._model.init, train=False)) params_rng, dropout_rng = jax.random.split(rng, 2) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index 63253ff34..875179119 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -9,7 +9,7 @@ def validation_target_value(self) -> float: @property def test_target_value(self) -> float: - return 0.07339 + return 0.073397 @property def step_hint(self) -> int: diff --git a/reference_algorithms/target_setting_algorithms/jax_nadamw.py b/reference_algorithms/target_setting_algorithms/jax_nadamw.py index 0b64741f2..21f2a7b2b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_nadamw.py +++ b/reference_algorithms/target_setting_algorithms/jax_nadamw.py @@ -7,7 +7,6 @@ import jax import jax.numpy as jnp import optax -from absl import logging from algorithmic_efficiency import spec from reference_algorithms.target_setting_algorithms import cosine_warmup @@ -153,7 +152,6 @@ def init_optimizer_state(workload: spec.Workload, del rng target_setting_step_hint = int(0.75 * workload.step_hint) - logging.info(f'target setting step hint: {target_setting_step_hint}') lr_schedule_fn = cosine_warmup.jax_cosine_warmup(target_setting_step_hint, hyperparameters) @@ -170,4 +168,4 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) - return jax_utils.replicate(optimizer_state), opt_update_fn, opt_update_fn + return 
jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 172b9edda..2a641b520 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -85,7 +85,7 @@ def update_params(workload: spec.Workload, del loss_type del eval_results - optimizer_state, opt_update_fn, lr_schedule_fn = optimizer_state + optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) if hasattr(hyperparameters, 'label_smoothing'): label_smoothing = hyperparameters.label_smoothing @@ -107,6 +107,5 @@ def update_params(workload: spec.Workload, { 'loss': loss[0], 'grad_norm': grad_norm[0], - 'learning_rate': lr_schedule_fn(global_step) }, global_step) - return (new_optimizer_state, opt_update_fn, lr_schedule_fn), new_params, new_model_state + return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/submission_runner.py b/submission_runner.py index 84e29eafc..2289d39d3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -530,7 +530,6 @@ def score_submission_on_workload(workload: spec.Workload, with profiler.profile('Train'): if 'imagenet' not in workload_name: imagenet_v2_data_dir = None - logging.info(f"Workload step hint: {workload.step_hint}") timing, metrics = train_once(workload, global_batch_size, global_eval_batch_size, data_dir, imagenet_v2_data_dir, @@ -598,11 +597,7 @@ def main(_): workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], workload_init_kwargs=workload_init_kwargs) - workload_path = workload_metadata['workload_path'] - workload_class = workload_metadata['workload_class_name'] - logging.info(f'workload : {workload_path}') - logging.info(f'workload class: {workload_class}') - logging.info(f'workload max run time: {workload.max_allowed_runtime_sec}') + experiment_name = FLAGS.experiment_name if experiment_name and FLAGS.append_timestamp: experiment_name += datetime.datetime.now().strftime('-%Y-%m-%d-%H-%M-%S') From 3afcc9f4a82545159f1b5b0ee34d5a372f972fc1 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 20:48:28 +0000 Subject: [PATCH 30/55] fix --- .../workloads/librispeech_deepspeech/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index 875179119..f9fd30b0d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,7 +1,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer import workload -class BaseDeepspeechLibrispeechWorkload(spec.Workload): +class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): @property def validation_target_value(self) -> float: From 3c72358535048ae7d846740efe5fe61fd7ec93fa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 21:01:13 +0000 Subject: [PATCH 31/55] fix --- .../workloads/librispeech_deepspeech/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index f9fd30b0d..5488e72b2 
100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,4 +1,4 @@ -from algorithmic_efficiency.workloads.librispeech_conformer import workload +from algorithmic_efficiency.workloads.librispeech_deepspeech import workload class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): From bc55c3423b8775242ebb0f2034868c88e624e43f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 23:38:34 +0000 Subject: [PATCH 32/55] fix imports --- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 74dfe2f97..a8a9bd3d6 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -8,8 +8,8 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ - LibriSpeechConformerWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ + BaseDeepspeechLibrispeechWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ models diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index b2a1d8ea1..7a7ccffaf 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -8,8 +8,8 @@ from algorithmic_efficiency.pytorch_utils import pytorch_setup from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize -from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ - LibriSpeechConformerWorkload +from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ + BaseDeepspeechLibrispeechWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechConfig from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ From 1323b57a9d5f7ce28d9abce43867475872cc7993 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 29 Sep 2023 23:42:38 +0000 Subject: [PATCH 33/55] import fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 4 ++-- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index a8a9bd3d6..dd34e6d97 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -8,10 +8,10 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from 
algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ - BaseDeepspeechLibrispeechWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ models +from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ + BaseDeepspeechLibrispeechWorkload class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 7a7ccffaf..94eb636c9 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -8,12 +8,12 @@ from algorithmic_efficiency.pytorch_utils import pytorch_setup from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize -from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ - BaseDeepspeechLibrispeechWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechConfig from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechEncoderDecoder +from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ + BaseDeepspeechLibrispeechWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() From 0c67481cc0b1043bccf7be5b94f0b7280a18801e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:17:06 +0000 Subject: [PATCH 34/55] fix --- .../workloads/librispeech_deepspeech/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index 5488e72b2..f9fd30b0d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,4 +1,4 @@ -from algorithmic_efficiency.workloads.librispeech_deepspeech import workload +from algorithmic_efficiency.workloads.librispeech_conformer import workload class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): From a76083bb2e2df0687537ed57d63ff0fc75c5055c Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:21:30 +0000 Subject: [PATCH 35/55] fix --- submission_runner.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/submission_runner.py b/submission_runner.py index f77c471e3..d72062cf3 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -143,6 +143,8 @@ None, 'End index to slice set of hyperparameters in tuning spearch space.') 'rng_seed', +flags.DEFINE_integer( + 'rng_seed', None, 'Value of rng seed. 
If None, a random seed will' 'be generated from hardware.') From 88e9d9ed6483410d10b162096c3ea70537967238 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:24:50 +0000 Subject: [PATCH 36/55] fix --- submission_runner.py | 35 ++++++++++++++++------------------- 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index d72062cf3..717ea2dc4 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -142,7 +142,6 @@ 'hparam_end_index', None, 'End index to slice set of hyperparameters in tuning spearch space.') - 'rng_seed', flags.DEFINE_integer( 'rng_seed', None, @@ -453,23 +452,21 @@ def train_once( return train_state['accumulated_submission_time'], metrics -def score_submission_on_workload( - workload: spec.Workload, - workload_name: str, - submission_path: str, - data_dir: str, - tuning_ruleset: str, - profiler: Optional[Profiler] = None, - max_global_steps: Optional[int] = None, - imagenet_v2_data_dir: Optional[str] = None, - tuning_search_space: Optional[str] = None, - num_tuning_trials: Optional[int] = None, - log_dir: Optional[str] = None, - save_checkpoints: Optional[bool] = True, - hparam_start_index: Optional[bool] = None, - hparam_end_index: Optional[bool] = None, - rng_seed: Optional[int] = None - ): +def score_submission_on_workload(workload: spec.Workload, + workload_name: str, + submission_path: str, + data_dir: str, + tuning_ruleset: str, + profiler: Optional[Profiler] = None, + max_global_steps: Optional[int] = None, + imagenet_v2_data_dir: Optional[str] = None, + tuning_search_space: Optional[str] = None, + num_tuning_trials: Optional[int] = None, + log_dir: Optional[str] = None, + save_checkpoints: Optional[bool] = True, + hparam_start_index: Optional[bool] = None, + hparam_end_index: Optional[bool] = None, + rng_seed: Optional[int] = None): # Expand paths because '~' may not be recognized data_dir = os.path.expanduser(data_dir) if imagenet_v2_data_dir: @@ -638,7 +635,7 @@ def main(_): log_dir=logging_dir_path, save_checkpoints=FLAGS.save_checkpoints, hparam_start_index=FLAGS.hparam_start_index, - hparam_end_index=FLAGS.hparam_end_index) + hparam_end_index=FLAGS.hparam_end_index, rng_seed=FLAGS.rng_seed) logging.info(f'Final {FLAGS.workload} score: {score}') From 1e4434914fcb964ae4c95b19fc348d5f5b9752fd Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:36:56 +0000 Subject: [PATCH 37/55] copy jax and pytorch loss_fn, model_fn and _eval_model_on_split to deepspeech classes --- .../librispeech_jax/workload.py | 105 ++++++++++++- .../librispeech_pytorch/workload.py | 140 ++++++++++++++++++ .../librispeech_deepspeech/workload.py | 2 +- 3 files changed, 245 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index dd34e6d97..99bd5f0ea 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -14,7 +14,7 @@ BaseDeepspeechLibrispeechWorkload -class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): +class LibriSpeechDeepSpeechWorkload(LibrispeechWorkload): def init_model_fn( self, @@ -50,3 +50,106 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' + + def model_fn( + self, + params: spec.ParameterContainer, + 
augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + del label_smoothing + logits, logit_paddings = logits_batch + targets, target_paddings = label_batch + logprobs = nn.log_softmax(logits) + per_example_losses = self.ctc_loss(logprobs, + logit_paddings, + targets, + target_paddings) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + per_example_losses *= mask_batch + mask_batch = jnp.logical_and(mask_batch, 1 - target_paddings) + else: + mask_batch = 1 - target_paddings + n_valid_examples = jnp.maximum(mask_batch.sum(), 1) + summed_loss = per_example_losses.sum() + return { + 'summed': summed_loss, + 'n_valid_examples': n_valid_examples, + 'per_example': per_example_losses, + } + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + del global_step + if model_state is not None: + # Sync batch statistics across replicas before evaluating. + model_state = self.sync_batch_stats(model_state) + + num_batches = int(math.ceil(num_examples / global_batch_size)) + if split not in self._eval_iters: + self._eval_iters[split] = self._build_input_queue( + rng, split, data_dir, global_batch_size, num_batches=num_batches) + + metrics_report = None + for _ in range(num_batches): + eval_batch = next(self._eval_iters[split]) + computed_metrics = self.eval_step_pmapped(params, + eval_batch, + model_state, + rng).unreplicate() + + if metrics_report is None: + metrics_report = computed_metrics + else: + # `merge` aggregates the metrics across batches. 
+ metrics_report = metrics_report.merge(computed_metrics) + + computed_metrics = metrics_report.compute() + + return computed_metrics \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 94eb636c9..ee447fd9d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -59,3 +59,143 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + del model_state + del rng + + model = params + if mode == spec.ForwardPassMode.EVAL: + model.eval() + if mode == spec.ForwardPassMode.TRAIN: + model.train() + model.apply( + functools.partial( + pytorch_utils.update_batch_norm_fn, + update_batch_norm=update_batch_norm)) + + contexts = { + spec.ForwardPassMode.EVAL: torch.no_grad, + spec.ForwardPassMode.TRAIN: contextlib.nullcontext, + } + with contexts[mode](): + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + logits, logits_paddings = model(inputs.to(DEVICE), + input_paddings.to(DEVICE)) + return (logits, logits_paddings), None + + # Does NOT apply regularization, which is left to the submitter to do in + # `update_params`. + def loss_fn( + self, + label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) + logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable + """Evaluate the (masked) loss function at (label_batch, logits_batch). + + Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of + valid examples in batch, 'per_example': 1-d array of per-example losses} + (not synced across devices). + """ + del label_smoothing + targets, target_paddings = label_batch + logits, logit_paddings = logits_batch + logprobs = torch.log_softmax(logits, dim=-1) + input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long() + target_lengths = torch.einsum('bh->b', 1 - target_paddings).long() + per_example_losses = self.ctc_loss( + logprobs.permute(1, 0, 2), + targets.long(), + input_lengths, + target_lengths) + # mask_batch is assumed to be shape [batch]. 
+ if mask_batch is not None: + per_example_losses *= mask_batch + mask_batch = torch.logical_and(mask_batch, target_lengths) + else: + mask_batch = target_lengths + n_valid_examples = mask_batch.sum().to(per_example_losses) + summed_loss = per_example_losses.sum() + n_valid_examples = max(n_valid_examples, 1) + return { + 'summed': summed_loss, + 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), + 'per_example': per_example_losses, + } + + def _eval_model_on_split(self, + split: str, + num_examples: int, + global_batch_size: int, + params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + rng: spec.RandomState, + data_dir: str, + global_step: int = 0) -> Dict[str, float]: + """Run a full evaluation of the model.""" + del global_step + data_rng, model_rng = prng.split(rng, 2) + if split not in self._eval_iters: + # These iterators repeat indefinitely. + self._eval_iters[split] = ( + self._build_input_queue( + data_rng, split, data_dir, global_batch_size=global_batch_size)) + + total_metrics = { + 'loss': torch.tensor(0., device=DEVICE), + 'lengths': torch.tensor(0., device=DEVICE), + 'word_errors': torch.tensor(0., device=DEVICE), + 'num_words': torch.tensor(0., device=DEVICE), + } + num_batches = int(math.ceil(num_examples / global_batch_size)) + if self.requires_sync_before_eval: + self.sync_sd(params) + for _ in range(num_batches): + batch = next(self._eval_iters[split]) + + (logits, logits_padding), _ = self.model_fn( + params, + batch, + model_state, + spec.ForwardPassMode.EVAL, + model_rng, + update_batch_norm=False) + decoded, decoded_paddings = self.greedy_decode(logits, logits_padding) + targets, target_paddings = batch['targets'] + word_errors, num_words = metrics.compute_wer( + decoded=decoded.cpu().numpy(), + decoded_paddings=decoded_paddings.cpu().numpy(), + targets=targets.cpu().numpy(), + target_paddings=target_paddings.cpu().numpy(), + tokenizer=self.tokenizer) + loss = self.loss_fn((targets, target_paddings), (logits, logits_padding)) + summed_loss = loss['summed'] + lengths = loss['n_valid_examples'] + batch_metrics = { + 'loss': summed_loss, + 'lengths': lengths, + 'word_errors': word_errors, + 'num_words': num_words, + } + total_metrics = { + k: v + batch_metrics[k] for k, v in total_metrics.items() + } + if USE_PYTORCH_DDP: + for metric in total_metrics.values(): + dist.all_reduce(metric) + return { + 'ctc_loss': + float(total_metrics['loss'].item() / + total_metrics['lengths'].item()), + 'wer': + float(total_metrics['word_errors'].item() / + total_metrics['num_words'].item()), + } diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index f9fd30b0d..0ac3bf422 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,7 +1,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer import workload -class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): +class BaseDeepspeechLibrispeechWorkload(workload.LibrispeechConformerWorkload): @property def validation_target_value(self) -> float: From 347d1df641b2f82d36e62d4fdce897dea0181eaf Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:50:17 +0000 Subject: [PATCH 38/55] fix imports --- .../librispeech_deepspeech/librispeech_jax/workload.py | 3 ++- .../librispeech_pytorch/workload.py | 9 ++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) 
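Editor's note on the loss_fn methods copied over in PATCH 37 above: they return a dictionary {'summed', 'n_valid_examples', 'per_example'} rather than a scalar. The following is a rough, self-contained sketch (an illustration only, not part of any patch in this series; the helper name and the numbers are made up) of how such a dictionary is typically reduced to a mean loss:

import numpy as np

def mean_loss_from_dict(loss_dict):
    # Summed loss divided by the number of valid (non-padding) entries gives
    # the mean loss a submission would typically optimize.
    return loss_dict['summed'] / max(loss_dict['n_valid_examples'], 1)

per_example = np.array([2.3, 1.7, 0.0])  # pretend per-example CTC losses; the last example is all padding
loss_dict = {
    'summed': per_example.sum(),                        # 4.0
    'n_valid_examples': int((per_example > 0).sum()),   # 2; the real code derives this from padding masks
    'per_example': per_example,
}
print(mean_loss_from_dict(loss_dict))  # 2.0

The floor of 1 in the division mirrors the jnp.maximum(mask_batch.sum(), 1) guard in the copied JAX code, which avoids dividing by zero when a batch is entirely padding.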
diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 99bd5f0ea..4a0234664 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,6 @@ import functools -from typing import Optional +import math +from typing import Dict, Iterator, Optional, Tuple from flax import jax_utils import jax diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index ee447fd9d..a28d93312 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,11 +1,18 @@ -from typing import Optional +import contextlib +import functools +import math +from typing import Dict, Iterator, Optional, Tuple import torch +import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from algorithmic_efficiency import param_utils +from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec from algorithmic_efficiency.pytorch_utils import pytorch_setup +import algorithmic_efficiency.random_utils as prng +from algorithmic_efficiency.workloads.librispeech_conformer import metrics from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ From f7789db63a4e6af0e3e38b6e41c1bdc095e40d79 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Sat, 30 Sep 2023 00:56:13 +0000 Subject: [PATCH 39/55] fix block --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4a0234664..4fa2e9395 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -117,7 +117,7 @@ def loss_fn( 'per_example': per_example_losses, } - def _eval_model_on_split(self, + def _eval_model_on_split(self, split: str, num_examples: int, global_batch_size: int, From 711a8fed949d0462a4b90cece9be94b2a96077d2 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 18:46:14 +0000 Subject: [PATCH 40/55] fix and formatting --- .../librispeech_jax/workload.py | 58 +++++++++---------- .../librispeech_pytorch/workload.py | 6 +- 2 files changed, 32 insertions(+), 32 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4fa2e9395..1ef38de82 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -15,7 +15,7 @@ BaseDeepspeechLibrispeechWorkload -class LibriSpeechDeepSpeechWorkload(LibrispeechWorkload): +class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): def init_model_fn( 
self, @@ -53,33 +53,33 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - variables = {'params': params, **model_state} - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - is_train_mode = mode == spec.ForwardPassMode.TRAIN - if update_batch_norm or is_train_mode: - (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) - return (logits, logit_paddings), new_model_state - else: - logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) - return (logits, logit_paddings), model_state + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @@ -153,4 +153,4 @@ def _eval_model_on_split(self, computed_metrics = metrics_report.compute() - return computed_metrics \ No newline at end of file + return computed_metrics diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index a28d93312..52663b854 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -66,7 +66,7 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] - + def model_fn( self, params: spec.ParameterContainer, @@ -97,7 +97,7 @@ def model_fn( logits, logits_paddings = model(inputs.to(DEVICE), input_paddings.to(DEVICE)) return (logits, logits_paddings), None - + # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. 
def loss_fn( @@ -137,7 +137,7 @@ def loss_fn( 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), 'per_example': per_example_losses, } - + def _eval_model_on_split(self, split: str, num_examples: int, From 345644b62831499142443a861eba0cf6bfa096f9 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 18:52:38 +0000 Subject: [PATCH 41/55] import fix --- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 52663b854..1f8b57b43 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,7 +1,7 @@ import contextlib import functools import math -from typing import Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed as dist From 36a0a73073c489d25f0c7491e6940d55c9b1dd8d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 19:12:05 +0000 Subject: [PATCH 42/55] test --- .../librispeech_deepspeech/librispeech_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1f8b57b43..bda00905e 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,7 +1,7 @@ import contextlib import functools import math -from typing import Dict, Optional, Tuple +from typing import Dict,git pu Optional, Tuple import torch import torch.distributed as dist From 33e1896a6793b3b38d6aaec819c89db8444f9055 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 19:12:41 +0000 Subject: [PATCH 43/55] add import fix --- .../librispeech_jax/workload.py | 59 ++++++++++--------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 1ef38de82..82ed73416 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,8 +1,9 @@ import functools import math -from typing import Dict, Iterator, Optional, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils +import flax.linen as nn import jax import jax.numpy as jnp import numpy as np @@ -53,33 +54,33 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - variables = {'params': params, **model_state} - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - is_train_mode = mode == spec.ForwardPassMode.TRAIN - if update_batch_norm or is_train_mode: - 
(logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) - return (logits, logit_paddings), new_model_state - else: - logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) - return (logits, logit_paddings), model_state + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @@ -153,4 +154,4 @@ def _eval_model_on_split(self, computed_metrics = metrics_report.compute() - return computed_metrics + return computed_metrics \ No newline at end of file From 2b40f5f91c1bd73f0486bbd6b8ddeb9283bcf234 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 19:20:23 +0000 Subject: [PATCH 44/55] fix --- .../librispeech_jax/workload.py | 56 +++++++++---------- .../librispeech_pytorch/workload.py | 2 +- 2 files changed, 29 insertions(+), 29 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 82ed73416..ae4c646ad 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -54,33 +54,33 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - variables = {'params': params, **model_state} - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - is_train_mode = mode == spec.ForwardPassMode.TRAIN - if update_batch_norm or is_train_mode: - (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) - return (logits, logit_paddings), new_model_state - else: - logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) - return (logits, logit_paddings), model_state + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool) -> Tuple[spec.Tensor, 
spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. @@ -154,4 +154,4 @@ def _eval_model_on_split(self, computed_metrics = metrics_report.compute() - return computed_metrics \ No newline at end of file + return computed_metrics diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index bda00905e..1f8b57b43 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,7 +1,7 @@ import contextlib import functools import math -from typing import Dict,git pu Optional, Tuple +from typing import Dict, Optional, Tuple import torch import torch.distributed as dist From 29c123e84a4798d3b5cb33db1f97f17d4b4b2fab Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 20:28:10 +0000 Subject: [PATCH 45/55] fix --- .../workloads/librispeech_deepspeech/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py index 0ac3bf422..f9fd30b0d 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py @@ -1,7 +1,7 @@ from algorithmic_efficiency.workloads.librispeech_conformer import workload -class BaseDeepspeechLibrispeechWorkload(workload.LibrispeechConformerWorkload): +class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): @property def validation_target_value(self) -> float: From 0229fd8c2bd35328380d2a3d57154a6c0af3344e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 21:29:27 +0000 Subject: [PATCH 46/55] add init for deepspeech workloads --- .../librispeech_deepspeech/librispeech_jax/workload.py | 7 +++++++ .../librispeech_deepspeech/librispeech_pytorch/workload.py | 7 +++++++ 2 files changed, 14 insertions(+) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index ae4c646ad..b5b067336 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -17,6 +17,13 @@ class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): + + def __init__(self, + tokenizer_vocab_path: Optional[str] = None, + use_specaug: bool = True) -> None: + super().__init__() + self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path) + self.use_specaug = use_specaug def init_model_fn( self, diff --git 
a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1f8b57b43..1ad38c344 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -29,6 +29,13 @@ class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): + def __init__(self, + tokenizer_vocab_path: Optional[str] = None, + use_specaug: bool = True) -> None: + super().__init__() + self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path) + self.use_specaug = use_specaug + def init_model_fn( self, rng: spec.RandomState, From 94854226bc15505352c33af47884f4710cd9ab32 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 22:08:11 +0000 Subject: [PATCH 47/55] missing import --- .../workloads/librispeech_deepspeech/librispeech_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index b5b067336..55c6e1f1b 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -10,6 +10,7 @@ from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec +from algorithmic_efficiency.workloads.librispeech_conformer import metrics from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ models from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ From 8f85841aa1563a2045c5a0a042706c098a8ea741 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 23:49:07 +0000 Subject: [PATCH 48/55] clean up deepspeech refactoring --- .../librispeech_jax/workload.py | 17 ++ .../librispeech_pytorch/workload.py | 17 ++ .../librispeech_jax/workload.py | 121 +------------ .../librispeech_pytorch/workload.py | 162 +----------------- .../librispeech_deepspeech/workload.py | 21 --- 5 files changed, 42 insertions(+), 296 deletions(-) delete mode 100644 algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 45d77ede4..4968aa881 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -344,3 +344,20 @@ def sync_batch_stats( new_model_state = model_state.copy( {'batch_stats': avg_fn(model_state['batch_stats'])}) return new_model_state + + @property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 24f4eb1fc..e9b0ed2b4 100644 --- 
a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -314,3 +314,20 @@ def _eval_model_on_split(self, float(total_metrics['word_errors'].item() / total_metrics['num_words'].item()), } + + @property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 55c6e1f1b..4086a5841 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,30 +1,20 @@ import functools -import math -from typing import Dict, Optional, Tuple +from typing import Optional from flax import jax_utils -import flax.linen as nn import jax import jax.numpy as jnp import numpy as np from algorithmic_efficiency import param_utils from algorithmic_efficiency import spec -from algorithmic_efficiency.workloads.librispeech_conformer import metrics +from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.workload import \ + LibriSpeechConformerWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_jax import \ models -from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ - BaseDeepspeechLibrispeechWorkload -class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = None, - use_specaug: bool = True) -> None: - super().__init__() - self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path) - self.use_specaug = use_specaug +class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, @@ -60,106 +50,3 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' - - def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - variables = {'params': params, **model_state} - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - is_train_mode = mode == spec.ForwardPassMode.TRAIN - if update_batch_norm or is_train_mode: - (logits, logit_paddings), new_model_state = self._model.apply( - variables, - inputs, - input_paddings, - train=True, - rngs={'dropout' : rng}, - mutable=['batch_stats']) - return (logits, logit_paddings), new_model_state - else: - logits, logit_paddings = self._model.apply( - variables, - inputs, - input_paddings, - train=False, - mutable=False) - return (logits, logit_paddings), model_state - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. 
- def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). - """ - del label_smoothing - logits, logit_paddings = logits_batch - targets, target_paddings = label_batch - logprobs = nn.log_softmax(logits) - per_example_losses = self.ctc_loss(logprobs, - logit_paddings, - targets, - target_paddings) - # mask_batch is assumed to be shape [batch]. - if mask_batch is not None: - per_example_losses *= mask_batch - mask_batch = jnp.logical_and(mask_batch, 1 - target_paddings) - else: - mask_batch = 1 - target_paddings - n_valid_examples = jnp.maximum(mask_batch.sum(), 1) - summed_loss = per_example_losses.sum() - return { - 'summed': summed_loss, - 'n_valid_examples': n_valid_examples, - 'per_example': per_example_losses, - } - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - del global_step - if model_state is not None: - # Sync batch statistics across replicas before evaluating. - model_state = self.sync_batch_stats(model_state) - - num_batches = int(math.ceil(num_examples / global_batch_size)) - if split not in self._eval_iters: - self._eval_iters[split] = self._build_input_queue( - rng, split, data_dir, global_batch_size, num_batches=num_batches) - - metrics_report = None - for _ in range(num_batches): - eval_batch = next(self._eval_iters[split]) - computed_metrics = self.eval_step_pmapped(params, - eval_batch, - model_state, - rng).unreplicate() - - if metrics_report is None: - metrics_report = computed_metrics - else: - # `merge` aggregates the metrics across batches. 
- metrics_report = metrics_report.merge(computed_metrics) - - computed_metrics = metrics_report.compute() - - return computed_metrics diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1ad38c344..1bb649ba8 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -1,40 +1,26 @@ -import contextlib -import functools -import math -from typing import Dict, Optional, Tuple +from typing import Optional import torch -import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from algorithmic_efficiency import param_utils -from algorithmic_efficiency import pytorch_utils from algorithmic_efficiency import spec from algorithmic_efficiency.pytorch_utils import pytorch_setup -import algorithmic_efficiency.random_utils as prng -from algorithmic_efficiency.workloads.librispeech_conformer import metrics from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \ initialize +from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.workload import \ + LibriSpeechConformerWorkload from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechConfig from algorithmic_efficiency.workloads.librispeech_deepspeech.librispeech_pytorch.models import \ DeepspeechEncoderDecoder -from algorithmic_efficiency.workloads.librispeech_deepspeech.workload import \ - BaseDeepspeechLibrispeechWorkload USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup() MAX_INPUT_LENGTH = 320000 -class LibriSpeechDeepSpeechWorkload(BaseDeepspeechLibrispeechWorkload): - - def __init__(self, - tokenizer_vocab_path: Optional[str] = None, - use_specaug: bool = True) -> None: - super().__init__() - self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path) - self.use_specaug = use_specaug +class LibriSpeechDeepSpeechWorkload(LibriSpeechConformerWorkload): def init_model_fn( self, @@ -73,143 +59,3 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] - - def model_fn( - self, - params: spec.ParameterContainer, - augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], - model_state: spec.ModelAuxiliaryState, - mode: spec.ForwardPassMode, - rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: - del model_state - del rng - - model = params - if mode == spec.ForwardPassMode.EVAL: - model.eval() - if mode == spec.ForwardPassMode.TRAIN: - model.train() - model.apply( - functools.partial( - pytorch_utils.update_batch_norm_fn, - update_batch_norm=update_batch_norm)) - - contexts = { - spec.ForwardPassMode.EVAL: torch.no_grad, - spec.ForwardPassMode.TRAIN: contextlib.nullcontext, - } - with contexts[mode](): - inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] - logits, logits_paddings = model(inputs.to(DEVICE), - input_paddings.to(DEVICE)) - return (logits, logits_paddings), None - - # Does NOT apply regularization, which is left to the submitter to do in - # `update_params`. 
- def loss_fn( - self, - label_batch: Tuple[spec.Tensor, spec.Tensor], # (label_batch, padding) - logits_batch: Tuple[spec.Tensor, spec.Tensor], # (logits_batch, padding) - mask_batch: Optional[spec.Tensor] = None, - label_smoothing: float = 0.0) -> Dict[str, spec.Tensor]: # differentiable - """Evaluate the (masked) loss function at (label_batch, logits_batch). - - Return {'summed': scalar summed loss, 'n_valid_examples': scalar number of - valid examples in batch, 'per_example': 1-d array of per-example losses} - (not synced across devices). - """ - del label_smoothing - targets, target_paddings = label_batch - logits, logit_paddings = logits_batch - logprobs = torch.log_softmax(logits, dim=-1) - input_lengths = torch.einsum('bh->b', 1 - logit_paddings).long() - target_lengths = torch.einsum('bh->b', 1 - target_paddings).long() - per_example_losses = self.ctc_loss( - logprobs.permute(1, 0, 2), - targets.long(), - input_lengths, - target_lengths) - # mask_batch is assumed to be shape [batch]. - if mask_batch is not None: - per_example_losses *= mask_batch - mask_batch = torch.logical_and(mask_batch, target_lengths) - else: - mask_batch = target_lengths - n_valid_examples = mask_batch.sum().to(per_example_losses) - summed_loss = per_example_losses.sum() - n_valid_examples = max(n_valid_examples, 1) - return { - 'summed': summed_loss, - 'n_valid_examples': torch.as_tensor(n_valid_examples, device=DEVICE), - 'per_example': per_example_losses, - } - - def _eval_model_on_split(self, - split: str, - num_examples: int, - global_batch_size: int, - params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - rng: spec.RandomState, - data_dir: str, - global_step: int = 0) -> Dict[str, float]: - """Run a full evaluation of the model.""" - del global_step - data_rng, model_rng = prng.split(rng, 2) - if split not in self._eval_iters: - # These iterators repeat indefinitely. 
- self._eval_iters[split] = ( - self._build_input_queue( - data_rng, split, data_dir, global_batch_size=global_batch_size)) - - total_metrics = { - 'loss': torch.tensor(0., device=DEVICE), - 'lengths': torch.tensor(0., device=DEVICE), - 'word_errors': torch.tensor(0., device=DEVICE), - 'num_words': torch.tensor(0., device=DEVICE), - } - num_batches = int(math.ceil(num_examples / global_batch_size)) - if self.requires_sync_before_eval: - self.sync_sd(params) - for _ in range(num_batches): - batch = next(self._eval_iters[split]) - - (logits, logits_padding), _ = self.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.EVAL, - model_rng, - update_batch_norm=False) - decoded, decoded_paddings = self.greedy_decode(logits, logits_padding) - targets, target_paddings = batch['targets'] - word_errors, num_words = metrics.compute_wer( - decoded=decoded.cpu().numpy(), - decoded_paddings=decoded_paddings.cpu().numpy(), - targets=targets.cpu().numpy(), - target_paddings=target_paddings.cpu().numpy(), - tokenizer=self.tokenizer) - loss = self.loss_fn((targets, target_paddings), (logits, logits_padding)) - summed_loss = loss['summed'] - lengths = loss['n_valid_examples'] - batch_metrics = { - 'loss': summed_loss, - 'lengths': lengths, - 'word_errors': word_errors, - 'num_words': num_words, - } - total_metrics = { - k: v + batch_metrics[k] for k, v in total_metrics.items() - } - if USE_PYTORCH_DDP: - for metric in total_metrics.values(): - dist.all_reduce(metric) - return { - 'ctc_loss': - float(total_metrics['loss'].item() / - total_metrics['lengths'].item()), - 'wer': - float(total_metrics['word_errors'].item() / - total_metrics['num_words'].item()), - } diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py deleted file mode 100644 index f9fd30b0d..000000000 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/workload.py +++ /dev/null @@ -1,21 +0,0 @@ -from algorithmic_efficiency.workloads.librispeech_conformer import workload - - -class BaseDeepspeechLibrispeechWorkload(workload.BaseLibrispeechWorkload): - - @property - def validation_target_value(self) -> float: - return 0.118232 - - @property - def test_target_value(self) -> float: - return 0.073397 - - @property - def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 - - @property - def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours From 5ea7b13ace1e4f8e9d7127c9ef9c6c54388aeb1e Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 23:53:04 +0000 Subject: [PATCH 49/55] fix lint --- .../librispeech_conformer/librispeech_pytorch/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index e9b0ed2b4..3024db04f 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -330,4 +330,4 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours \ No newline at end of file + return 55_506 # ~15.4 hours From fd1e49cb09ea2bd12917259fbab3fc68c1e8a9fa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 23:57:34 +0000 Subject: [PATCH 50/55] add targets 
to deepspeech --- .../librispeech_jax/workload.py | 17 ----------------- .../librispeech_pytorch/workload.py | 16 ---------------- .../librispeech_jax/workload.py | 17 +++++++++++++++++ .../librispeech_pytorch/workload.py | 17 +++++++++++++++++ 4 files changed, 34 insertions(+), 33 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 4968aa881..45d77ede4 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -344,20 +344,3 @@ def sync_batch_stats( new_model_state = model_state.copy( {'batch_stats': avg_fn(model_state['batch_stats'])}) return new_model_state - - @property - def validation_target_value(self) -> float: - return 0.118232 - - @property - def test_target_value(self) -> float: - return 0.073397 - - @property - def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 - - @property - def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 3024db04f..465aeacc0 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -315,19 +315,3 @@ def _eval_model_on_split(self, total_metrics['num_words'].item()), } - @property - def validation_target_value(self) -> float: - return 0.118232 - - @property - def test_target_value(self) -> float: - return 0.073397 - - @property - def step_hint(self) -> int: - """Max num steps the baseline algo was given to reach the target.""" - return 48_000 - - @property - def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 4086a5841..cb07075f4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -50,3 +50,20 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' + + @property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 1bb649ba8..6c9d31ce0 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -59,3 +59,20 @@ def init_model_fn( def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key in ['lin.weight', 'lin.bias'] + + 
@property + def validation_target_value(self) -> float: + return 0.118232 + + @property + def test_target_value(self) -> float: + return 0.073397 + + @property + def step_hint(self) -> int: + """Max num steps the baseline algo was given to reach the target.""" + return 48_000 + + @property + def max_allowed_runtime_sec(self) -> int: + return 55_506 # ~15.4 hours \ No newline at end of file From 42745388ffd2d5f0d045b8e5551bd84f860f3659 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Mon, 2 Oct 2023 23:58:23 +0000 Subject: [PATCH 51/55] lint --- .../librispeech_conformer/librispeech_pytorch/workload.py | 1 - .../librispeech_deepspeech/librispeech_jax/workload.py | 1 + .../librispeech_deepspeech/librispeech_pytorch/workload.py | 3 ++- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 465aeacc0..24f4eb1fc 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -314,4 +314,3 @@ def _eval_model_on_split(self, float(total_metrics['word_errors'].item() / total_metrics['num_words'].item()), } - diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index cb07075f4..73ea57e5c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -67,3 +67,4 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: return 55_506 # ~15.4 hours + \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index 6c9d31ce0..e7ae3f752 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -75,4 +75,5 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: - return 55_506 # ~15.4 hours \ No newline at end of file + return 55_506 # ~15.4 hours + \ No newline at end of file From 09eed7e01e4250d14becb9053cb71c37e84c62aa Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Tue, 3 Oct 2023 00:30:39 +0000 Subject: [PATCH 52/55] formatting --- .../workloads/librispeech_deepspeech/librispeech_jax/workload.py | 1 - .../librispeech_deepspeech/librispeech_pytorch/workload.py | 1 - 2 files changed, 2 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 73ea57e5c..cb07075f4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -67,4 +67,3 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: return 55_506 # ~15.4 hours - \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py 
b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py index e7ae3f752..09a4b0aa4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/workload.py @@ -76,4 +76,3 @@ def step_hint(self) -> int: @property def max_allowed_runtime_sec(self) -> int: return 55_506 # ~15.4 hours - \ No newline at end of file From a66413401b87a1379a29e6c7e0abbe8c72a90a71 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 5 Oct 2023 12:56:04 -0700 Subject: [PATCH 53/55] Change padding for Deepspeech LSTM layer Remove the global_batch_size arg from the shard_and_maybe_pad_np call. As a result, the final batch of the librispeech validation and test sets is padded only as much as needed to split it evenly across devices, so no device batch consists entirely of padding. Workaround for https://github.com/mlcommons/algorithmic-efficiency/issues/523. --- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 45d77ede4..bc7eae3b8 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -144,7 +144,7 @@ def _build_input_queue( } padded_batch = data_utils.shard_and_maybe_pad_np( - numpy_batch, padding_value=1.0, global_batch_size=global_batch_size) + numpy_batch, padding_value=1.0) yield padded_batch # Does NOT apply regularization, which is left to the submitter to do in From 2cb31a2e3ca4c0ed29a45a87fd5eba4c788ff954 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Oct 2023 12:38:12 +0200 Subject: [PATCH 54/55] Adjust runtime budget for self-tuning ruleset and check that tuning search space is None --- submission_runner.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/submission_runner.py b/submission_runner.py index 717ea2dc4..47730d3fc 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -340,9 +340,12 @@ def train_once( train_state['accumulated_submission_time'] += ( train_step_end_time - train_state['last_step_end_time']) + # Use 3x the runtime budget for the self-tuning ruleset. + max_allowed_runtime_sec = ( + workload.max_allowed_runtime_sec if FLAGS.tuning_ruleset == 'external' + else 3 * workload.max_allowed_runtime_sec) train_state['is_time_remaining'] = ( - train_state['accumulated_submission_time'] < - workload.max_allowed_runtime_sec) + train_state['accumulated_submission_time'] < max_allowed_runtime_sec) # Check if submission is eligible for an untimed eval.
if ((train_step_end_time - train_state['last_eval_time']) >= workload.eval_period_time_sec or train_state['training_complete']): @@ -567,6 +570,9 @@ def score_submission_on_workload(workload: spec.Workload, logging.info(f'Total number of evals: {num_evals}') logging.info('=' * 20) else: + if tuning_search_space is not None: + raise ValueError( + 'Cannot provide a tuning search space when using self tuning.') if not rng_seed: rng_seed = struct.unpack('q', os.urandom(8))[0] rng = prng.PRNGKey(rng_seed) From 5407ab1e78147f459af6d771426be8891348a960 Mon Sep 17 00:00:00 2001 From: runame Date: Fri, 6 Oct 2023 13:01:19 +0200 Subject: [PATCH 55/55] Remove test target from scoring --- scoring/scoring.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/scoring/scoring.py b/scoring/scoring.py index fff152255..12aae1357 100644 --- a/scoring/scoring.py +++ b/scoring/scoring.py @@ -53,7 +53,7 @@ def generate_eval_cols(metrics): - splits = ['train', 'validation', 'test'] + splits = ['train', 'validation'] return [f'{split}/{col}' for split, col in itertools.product(splits, metrics)] @@ -108,15 +108,13 @@ def get_index_that_reaches_best(workload_df, metric_col): def get_index_that_reaches_target(workload_df, validation_metric, - test_metric, - validation_target, - test_target): + validation_target): """Get the eval index in which a workload reaches the target metric_col. Args: workload_df: A subset of a submission's trials DataFrame that includes only the trials in a single workload. - metric_col: Name of array column in workload_df (e.g., `validation/l1_loss`). + metric_col: Name of array column in workload_df (e.g. `validation/l1_loss`). target: Target value for metric_col. Returns: @@ -125,20 +123,13 @@ def get_index_that_reaches_target(workload_df, """ is_minimized = check_if_minimized(validation_metric) validation_series = workload_df[validation_metric] - test_series = workload_df[test_metric] - validation_series = validation_series[validation_series != np.nan] - validation_series = validation_series[test_series != np.nan] - test_series = test_series[validation_series != np.nan] - test_series = test_series[test_series != np.nan] op = operator.le if is_minimized else operator.ge validation_target_reached = validation_series.apply( lambda x: op(x, validation_target)) - test_target_reached = test_series.apply(lambda x: op(x, test_target)) - target_reached = pd.Series(validation_target_reached[0] - & test_target_reached[0]) + target_reached = pd.Series(validation_target_reached[0]) # Remove trials that never reach the target target_reached = target_reached[target_reached.apply(np.any)] @@ -188,12 +179,10 @@ def get_times_for_submission(submission, workload_init_kwargs=workload_init_kwargs) metric_name = workload_obj.target_metric_name validation_metric = f'validation/{metric_name}' - test_metric = f'test/{metric_name}' validation_target = workload_obj.validation_target_value - test_target = workload_obj.test_target_value trial_idx, time_idx = get_index_that_reaches_target( - group, validation_metric, test_metric, validation_target, test_target) + group, validation_metric, validation_target) if time_idx > -1: time_val = group[time_col].loc[trial_idx][time_idx] else: