From fa7359b85ec4459cd76c2ef932beb339c2eab708 Mon Sep 17 00:00:00 2001 From: Zack Nado Date: Fri, 7 Oct 2022 17:09:30 -0400 Subject: [PATCH] saving work refactoring target_setting_runs/ --- target_setting_runs/README.md | 138 ++++++++++++++++++ target_setting_runs/cosine_warmup.py | 37 +++++ target_setting_runs/criteo1tb/__init__.py | 0 .../criteo1tb/pytorch_submission.py | 0 .../criteo1tb/tuning_search_space.json | 2 +- target_setting_runs/fastmri/__init__.py | 0 target_setting_runs/fastmri/jax_submission.py | 0 target_setting_runs/get_batch_size.py | 23 +++ .../imagenet_resnet/__init__.py | 0 .../imagenet_resnet/jax_submission.py | 94 ------------ .../imagenet_resnet/pytorch_submission.py | 63 -------- target_setting_runs/imagenet_vit/__init__.py | 0 .../imagenet_vit/jax_submission.py | 90 ------------ .../imagenet_vit/pytorch_submission.py | 61 -------- target_setting_runs/jax_adamw.py | 99 +------------ target_setting_runs/jax_nadamw.py | 76 +++++----- target_setting_runs/jax_nesterov.py | 8 + ...x_submission.py => jax_submission_base.py} | 38 +++-- .../librispeech_conformer/__init__.py | 0 .../librispeech_conformer/jax_submission.py | 0 .../pytorch_submission.py | 0 .../librispeech_deepspeech/__init__.py | 0 .../librispeech_deepspeech/jax_submission.py | 0 .../pytorch_submission.py | 0 target_setting_runs/ogbg/__init__.py | 0 target_setting_runs/ogbg/jax_submission.py | 92 ------------ .../ogbg/pytorch_submission.py | 63 -------- target_setting_runs/pytorch_adamw.py | 28 ++-- target_setting_runs/pytorch_nadamw.py | 77 +++++----- target_setting_runs/pytorch_nesterov.py | 6 + ...bmission.py => pytorch_submission_base.py} | 7 +- target_setting_runs/wmt/__init__.py | 0 target_setting_runs/wmt/jax_submission.py | 94 ------------ target_setting_runs/wmt/pytorch_submission.py | 63 -------- 34 files changed, 325 insertions(+), 834 deletions(-) create mode 100644 target_setting_runs/README.md create mode 100644 target_setting_runs/cosine_warmup.py delete mode 100644 target_setting_runs/criteo1tb/__init__.py delete mode 100644 target_setting_runs/criteo1tb/pytorch_submission.py delete mode 100644 target_setting_runs/fastmri/__init__.py delete mode 100644 target_setting_runs/fastmri/jax_submission.py create mode 100644 target_setting_runs/get_batch_size.py delete mode 100644 target_setting_runs/imagenet_resnet/__init__.py delete mode 100644 target_setting_runs/imagenet_resnet/jax_submission.py delete mode 100644 target_setting_runs/imagenet_resnet/pytorch_submission.py delete mode 100644 target_setting_runs/imagenet_vit/__init__.py delete mode 100644 target_setting_runs/imagenet_vit/jax_submission.py delete mode 100644 target_setting_runs/imagenet_vit/pytorch_submission.py rename target_setting_runs/{criteo1tb/jax_submission.py => jax_submission_base.py} (73%) delete mode 100644 target_setting_runs/librispeech_conformer/__init__.py delete mode 100644 target_setting_runs/librispeech_conformer/jax_submission.py delete mode 100644 target_setting_runs/librispeech_conformer/pytorch_submission.py delete mode 100644 target_setting_runs/librispeech_deepspeech/__init__.py delete mode 100644 target_setting_runs/librispeech_deepspeech/jax_submission.py delete mode 100644 target_setting_runs/librispeech_deepspeech/pytorch_submission.py delete mode 100644 target_setting_runs/ogbg/__init__.py delete mode 100644 target_setting_runs/ogbg/jax_submission.py delete mode 100644 target_setting_runs/ogbg/pytorch_submission.py rename target_setting_runs/{fastmri/pytorch_submission.py => pytorch_submission_base.py} 
(86%) delete mode 100644 target_setting_runs/wmt/__init__.py delete mode 100644 target_setting_runs/wmt/jax_submission.py delete mode 100644 target_setting_runs/wmt/pytorch_submission.py diff --git a/target_setting_runs/README.md b/target_setting_runs/README.md new file mode 100644 index 000000000..542d42b1d --- /dev/null +++ b/target_setting_runs/README.md @@ -0,0 +1,138 @@ +# Target Setting Run Replications +The original target-setting runs were performed on Google TPUv2-8 machines. + +## Criteo +Target was set using AdamW with a linear warmup cosine decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=criteo1tb \ + --submission_path=target_setting_runs/jax_adamw.py \ + --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=criteo1tb \ + --submission_path=target_setting_runs/pytorch_adamw.py \ + --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json +``` + +## FastMRI +Target was set using NAdamW with a linear warmup cosine decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=fastmri \ + --submission_path=target_setting_runs/jax_nadamw.py \ + --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=fastmri \ + --submission_path=target_setting_runs/pytorch_nadamw.py \ + --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json +``` + +## ImageNet-ResNet +Target was set using Nesterov with a linear warmup and linear decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=imagenet_resnet \ + --submission_path=target_setting_runs/jax_nesterov.py \ + --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=imagenet_resnet \ + --submission_path=target_setting_runs/pytorch_nesterov.py \ + --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json +``` + +## ImageNet-ViT +Target was set using NAdamW with a linear warmup cosine decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=imagenet_vit \ + --submission_path=target_setting_runs/jax_nadamw.py \ + --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=imagenet_vit \ + --submission_path=target_setting_runs/pytorch_nadamw.py \ + --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json +``` + +## Librispeech-Conformer +Target was set using AdamW with a linear warmup cosine decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=librispeech_conformer \ + --submission_path=target_setting_runs/jax_adamw.py \ + --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=librispeech_conformer \ + --submission_path=target_setting_runs/pytorch_adamw.py \ + --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json +``` + +## Librispeech-Deepspeech +Target was set using NAdamW with a linear warmup cosine decay LR schedule.
+```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=librispeech_deepspeech \ + --submission_path=target_setting_runs/jax_nadamw.py \ + --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=librispeech_deepspeech \ + --submission_path=target_setting_runs/pytorch_nadamw.py \ + --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json +``` + +## OGBG +Target was set using Nesterov with a linear warmup and linear decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=ogbg \ + --submission_path=target_setting_runs/jax_nesterov.py \ + --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=ogbg \ + --submission_path=target_setting_runs/pytorch_nesterov.py \ + --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json +``` + +## WMT +Target was set using AdamW with a linear warmup cosine decay LR schedule. +```bash +python3 submission_runner.py \ + --framework=jax \ + --workload=wmt \ + --submission_path=target_setting_runs/jax_adamw.py \ + --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json +``` +```bash +python3 submission_runner.py \ + --framework=pytorch \ + --workload=wmt \ + --submission_path=target_setting_runs/pytorch_adamw.py \ + --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json +``` diff --git a/target_setting_runs/cosine_warmup.py b/target_setting_runs/cosine_warmup.py new file mode 100644 index 000000000..b3676de21 --- /dev/null +++ b/target_setting_runs/cosine_warmup.py @@ -0,0 +1,37 @@ +"""Implementations of a linear warmup then cosine decay LR schedule.""" + +import optax +from torch.optim.lr_scheduler import CosineAnnealingLR +from torch.optim.lr_scheduler import LinearLR +from torch.optim.lr_scheduler import SequentialLR + + +def jax_cosine_warmup(hyperparameters): + # Create learning rate schedule.
+ warmup_fn = optax.linear_schedule( + init_value=0., + end_value=hyperparameters.learning_rate, + transition_steps=hyperparameters.warmup_steps) + cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, + 1) + cosine_fn = optax.cosine_decay_schedule( + init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) + schedule_fn = optax.join_schedules( + schedules=[warmup_fn, cosine_fn], + boundaries=[hyperparameters.warmup_steps]) + return schedule_fn + + +def pytorch_cosine_warmup(hyperparameters, optimizer): + warmup = LinearLR( + optimizer, + start_factor=1e-10, + end_factor=1., + total_iters=hyperparameters.warmup_steps) + cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, + 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, + schedulers=[warmup, cosine_decay], + milestones=[hyperparameters.warmup_steps]) diff --git a/target_setting_runs/criteo1tb/__init__.py b/target_setting_runs/criteo1tb/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/criteo1tb/pytorch_submission.py b/target_setting_runs/criteo1tb/pytorch_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/criteo1tb/tuning_search_space.json b/target_setting_runs/criteo1tb/tuning_search_space.json index 781863a35..a5c423711 100644 --- a/target_setting_runs/criteo1tb/tuning_search_space.json +++ b/target_setting_runs/criteo1tb/tuning_search_space.json @@ -16,7 +16,7 @@ }, "warmup_steps": { "feasible_points": [ - 1600 + 200 ] }, "num_steps": { diff --git a/target_setting_runs/fastmri/__init__.py b/target_setting_runs/fastmri/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/fastmri/jax_submission.py b/target_setting_runs/fastmri/jax_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/get_batch_size.py b/target_setting_runs/get_batch_size.py new file mode 100644 index 000000000..7b3bc0408 --- /dev/null +++ b/target_setting_runs/get_batch_size.py @@ -0,0 +1,23 @@ +"""Batch size selection submission function.""" + + +def get_batch_size(workload_name): + # Return the global batch size. + if workload_name == 'criteo1tb_dlrm': + return 524288 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') diff --git a/target_setting_runs/imagenet_resnet/__init__.py b/target_setting_runs/imagenet_resnet/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/imagenet_resnet/jax_submission.py b/target_setting_runs/imagenet_resnet/jax_submission.py deleted file mode 100644 index 619106580..000000000 --- a/target_setting_runs/imagenet_resnet/jax_submission.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Jax submission for the target-setting run on ImageNet-ResNet with Nesterov. 
-""" - -import functools -from typing import Dict, List, Tuple - -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.jax_nesterov import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=True) - loss = jnp.mean( - workload.loss_fn( - batch['targets'], logits, label_smoothing=label_smoothing)) - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, label_smoothing) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/target_setting_runs/imagenet_resnet/pytorch_submission.py b/target_setting_runs/imagenet_resnet/pytorch_submission.py deleted file mode 100644 index 5e611341e..000000000 --- a/target_setting_runs/imagenet_resnet/pytorch_submission.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -PyTorch submission for the target-setting run on ImageNet-ResNet with Nesterov. -""" - -from typing import Dict, List, Tuple - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.pytorch_nesterov import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 1024 - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - loss = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - label_smoothing=label_smoothing).mean() - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) diff --git a/target_setting_runs/imagenet_vit/__init__.py b/target_setting_runs/imagenet_vit/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/imagenet_vit/jax_submission.py b/target_setting_runs/imagenet_vit/jax_submission.py deleted file mode 100644 index a9bcee01b..000000000 --- a/target_setting_runs/imagenet_vit/jax_submission.py +++ /dev/null @@ -1,90 +0,0 @@ -"""Jax submission for the target-setting run on ImageNet-ViT with AdamW.""" - -import functools -from typing import Dict, List, Tuple - -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.jax_adamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - label_smoothing): - - def _loss_fn(params): - """loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - dropout_rate=0.0, # Default. 
- aux_dropout_rate=None, - update_batch_norm=True) - loss = jnp.mean( - workload.loss_fn( - batch['targets'], logits, label_smoothing=label_smoothing)) - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - aux, grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - new_model_state, _ = aux[1] - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, label_smoothing) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/target_setting_runs/imagenet_vit/pytorch_submission.py b/target_setting_runs/imagenet_vit/pytorch_submission.py deleted file mode 100644 index 21ca940a7..000000000 --- a/target_setting_runs/imagenet_vit/pytorch_submission.py +++ /dev/null @@ -1,61 +0,0 @@ -"""PyTorch submission for the target-setting run on ImageNet-ViT with AdamW.""" - -from typing import Dict, List, Tuple - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.pytorch_adamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 1024 - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits_batch, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=0.0, # Default. 
- aux_dropout_rate=None, - update_batch_norm=True) - - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - loss = workload.loss_fn( - label_batch=batch['targets'], - logits_batch=logits_batch, - label_smoothing=label_smoothing).mean() - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, new_model_state) diff --git a/target_setting_runs/jax_adamw.py b/target_setting_runs/jax_adamw.py index b2b7300e1..a91a513f1 100644 --- a/target_setting_runs/jax_adamw.py +++ b/target_setting_runs/jax_adamw.py @@ -1,23 +1,17 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" -import functools -from typing import Dict, List, Tuple - from flax import jax_utils import jax -from jax import lax import jax.numpy as jnp import optax from algorithmic_efficiency import spec +from target_setting_runs import cosine_warmup from target_setting_runs.data_selection import \ data_selection # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - if workload_name == 'criteo1tb_dlrm': - return 524288 - +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.jax_submission_base import \ + update_params # pylint: disable=unused-import def init_optimizer_state(workload: spec.Workload, @@ -30,18 +24,7 @@ def init_optimizer_state(workload: spec.Workload, del model_state del rng - # Create learning rate schedule. - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) - cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, - 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hyperparameters.warmup_steps]) + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(hyperparameters) # Create optimizer. 
params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), @@ -49,7 +32,7 @@ def init_optimizer_state(workload: spec.Workload, epsilon = ( hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) opt_init_fn, opt_update_fn = optax.adamw( - learning_rate=schedule_fn, + learning_rate=lr_schedule_fn, b1=hyperparameters.beta1, b2=hyperparameters.beta2, eps=epsilon, @@ -57,71 +40,3 @@ def init_optimizer_state(workload: spec.Workload, optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn - - -@functools.partial( - jax.pmap, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - label_smoothing): - - def _loss_fn(params): - """Loss function used for training.""" - logits, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - dropout_rate=None, - aux_dropout_rate=None, - update_batch_norm=False) - loss = jnp.mean( - workload.loss_fn( - batch['targets'], logits, label_smoothing=label_smoothing)) - return loss, (new_model_state, logits) - - grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (new_model_state, _), grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, - current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - per_device_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, label_smoothing) - - return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/target_setting_runs/jax_nadamw.py b/target_setting_runs/jax_nadamw.py index 4f8441268..2dc57c514 100644 --- a/target_setting_runs/jax_nadamw.py +++ b/target_setting_runs/jax_nadamw.py @@ -1,4 +1,4 @@ -"""Submission file for a NAdamW optimizer in Jax.""" +"""Submission file for a NAdamW optimizer with warmup+cosine LR in Jax.""" from typing import Any, Callable, NamedTuple, Optional, Union @@ -10,44 +10,14 @@ from algorithmic_efficiency import spec - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del model_params - del model_state 
- del rng - - # Create learning rate schedule. - warmup_fn = optax.linear_schedule( - init_value=0., - end_value=hyperparameters.learning_rate, - transition_steps=hyperparameters.warmup_steps) - cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, - 1) - cosine_fn = optax.cosine_decay_schedule( - init_value=hyperparameters.learning_rate, decay_steps=cosine_steps) - schedule_fn = optax.join_schedules( - schedules=[warmup_fn, cosine_fn], - boundaries=[hyperparameters.warmup_steps]) - - # Create optimizer. - params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), - workload.param_shapes) - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) - opt_init_fn, opt_update_fn = nadamw( - learning_rate=schedule_fn, - b1=hyperparameters.beta1, - b2=hyperparameters.beta2, - eps=epsilon, - weight_decay=hyperparameters.l2) - optimizer_state = opt_init_fn(params_zeros_like) - - return jax_utils.replicate(optimizer_state), opt_update_fn +from algorithmic_efficiency import spec +from target_setting_runs import cosine_warmup +from target_setting_runs.data_selection import \ + data_selection # pylint: disable=unused-import +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.jax_submission_base import \ + update_params # pylint: disable=unused-import # Forked from @@ -171,3 +141,31 @@ def scale_by_learning_rate(learning_rate, flip_sign=True): if callable(learning_rate): return optax.scale_by_schedule(lambda count: m * learning_rate(count)) return optax.scale(m * learning_rate) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del model_params + del model_state + del rng + + lr_schedule_fn = cosine_warmup.jax_cosine_warmup(hyperparameters) + + # Create optimizer. + params_zeros_like = jax.tree_map(lambda s: jnp.zeros(s.shape_tuple), + workload.param_shapes) + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + opt_init_fn, opt_update_fn = nadamw( + learning_rate=lr_schedule_fn, + b1=hyperparameters.beta1, + b2=hyperparameters.beta2, + eps=epsilon, + weight_decay=hyperparameters.l2) + optimizer_state = opt_init_fn(params_zeros_like) + + return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/target_setting_runs/jax_nesterov.py b/target_setting_runs/jax_nesterov.py index 227fa5fe6..92e695b9e 100644 --- a/target_setting_runs/jax_nesterov.py +++ b/target_setting_runs/jax_nesterov.py @@ -8,6 +8,12 @@ import optax from algorithmic_efficiency import spec +from target_setting_runs.data_selection import \ + data_selection # pylint: disable=unused-import +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.jax_submission_base import \ + update_params # pylint: disable=unused-import def init_optimizer_state(workload: spec.Workload, @@ -58,11 +64,13 @@ def create_lr_schedule_fn( # optimizer_lib/optimizers.py. def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): r"""A customizable gradient descent optimizer. + NOTE: We apply weight decay **before** computing the momentum update. This is equivalent to applying WD after for heavy-ball momentum, but slightly different when using Nesterov accelleration. 
This is the same as how the Flax optimizers handle weight decay https://flax.readthedocs.io/en/latest/_modules/flax/optim/momentum.html. + Args: learning_rate: The learning rate. Expected as the positive learning rate, for example `\alpha` in `w -= \alpha * u` (as opposed to `\alpha`). diff --git a/target_setting_runs/criteo1tb/jax_submission.py b/target_setting_runs/jax_submission_base.py similarity index 73% rename from target_setting_runs/criteo1tb/jax_submission.py rename to target_setting_runs/jax_submission_base.py index 42a0cc0e9..19a0a19fc 100644 --- a/target_setting_runs/criteo1tb/jax_submission.py +++ b/target_setting_runs/jax_submission_base.py @@ -1,8 +1,5 @@ -""" -Jax submission for the target-setting run on Criteo1TB DLRM-Small with AdamW. -""" - +"""Update submission function in Jax.""" import functools from typing import Dict, List, Tuple import jax @@ -11,20 +9,15 @@ import optax from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 524288 +_GRAD_CLIP_EPS = 1e-6 @@ -33,6 +26,7 @@ def pmapped_train_step(workload, jax.pmap, axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None), + in_axes=(None, None, 0, 0, 0, 0, 0, None, None), static_broadcasted_argnums=(0, 1)) def pmapped_train_step(workload, opt_update_fn, @@ -33,6 +26,7 @@ def pmapped_train_step(workload, current_param_container, batch, rng, + grad_clip, label_smoothing): def _loss_fn(params): @@ -43,6 +37,7 @@ def _loss_fn(params): model_state, spec.ForwardPassMode.TRAIN, rng, + # There was no dropout rate tuning in the target setting runs. dropout_rate=None, aux_dropout_rate=None, update_batch_norm=False) @@ -54,6 +49,13 @@ def _loss_fn(params): grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) (new_model_state, _), grad = grad_fn(current_param_container) grad = lax.pmean(grad, axis_name='batch') + + if grad_clip is not None: + grad_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_leaves(grad))) + grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) + grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) + grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) + updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) updated_params = optax.apply_updates(current_param_container, updates) @@ -79,11 +81,17 @@ def update_params(workload: spec.Workload, optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) + if hasattr(hyperparameters, 'label_smoothing'): + label_smoothing = hyperparameters.label_smoothing + else: + label_smoothing = 0.0 + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None new_model_state, new_optimizer_state, new_params = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, per_device_rngs, label_smoothing) + current_param_container, batch, per_device_rngs, grad_clip, + label_smoothing) return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/target_setting_runs/librispeech_conformer/__init__.py b/target_setting_runs/librispeech_conformer/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git
a/target_setting_runs/librispeech_conformer/jax_submission.py b/target_setting_runs/librispeech_conformer/jax_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/librispeech_conformer/pytorch_submission.py b/target_setting_runs/librispeech_conformer/pytorch_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/librispeech_deepspeech/__init__.py b/target_setting_runs/librispeech_deepspeech/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/librispeech_deepspeech/jax_submission.py b/target_setting_runs/librispeech_deepspeech/jax_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/librispeech_deepspeech/pytorch_submission.py b/target_setting_runs/librispeech_deepspeech/pytorch_submission.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/ogbg/__init__.py b/target_setting_runs/ogbg/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/ogbg/jax_submission.py b/target_setting_runs/ogbg/jax_submission.py deleted file mode 100644 index b896be5d2..000000000 --- a/target_setting_runs/ogbg/jax_submission.py +++ /dev/null @@ -1,92 +0,0 @@ -"""Jax submission for the target-setting run on OGBG with AdamW.""" - -from typing import Dict, List, Tuple - -import jax -from jax import lax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.jax_nadamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 512 - - -def train_step(workload, - opt_update_fn, - model_state, - optimizer_state, - current_param_container, - batch, - rng, - label_smoothing): - - def loss_fn(params): - logits_batch, new_model_state = workload.model_fn( - params, - batch, - model_state, - spec.ForwardPassMode.TRAIN, - rng, - dropout_rate=0.1, # Default. 
- aux_dropout_rate=None, - update_batch_norm=True) - mask_batch = batch['weights'] - per_example_losses = workload.loss_fn( - batch['targets'], - logits_batch, - mask_batch, - label_smoothing=label_smoothing) - mean_loss = ( - jnp.sum(jnp.where(mask_batch, per_example_losses, 0)) / - jnp.sum(mask_batch)) - return mean_loss, new_model_state - - grad_fn = jax.value_and_grad(loss_fn, has_aux=True) - (_, new_model_state), grad = grad_fn(current_param_container) - grad = lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - optimizer_state, opt_update_fn = optimizer_state - pmapped_train_step = jax.pmap( - train_step, - axis_name='batch', - in_axes=(None, None, 0, 0, 0, 0, 0, None), - static_broadcasted_argnums=(0, 1)) - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - new_model_state, new_optimizer_state, new_params = pmapped_train_step( - workload, opt_update_fn, model_state, optimizer_state, - current_param_container, batch, dropout_rngs, label_smoothing) - return (new_optimizer_state, opt_update_fn), new_params, new_model_state diff --git a/target_setting_runs/ogbg/pytorch_submission.py b/target_setting_runs/ogbg/pytorch_submission.py deleted file mode 100644 index 40f724044..000000000 --- a/target_setting_runs/ogbg/pytorch_submission.py +++ /dev/null @@ -1,63 +0,0 @@ -"""PyTorch submission for the target-setting run on OGBG with AdamW.""" - -from typing import Dict, List, Tuple - -import torch - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.pytorch_nadamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. 
- del workload_name - return 512 - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del loss_type - del eval_results - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits, new_model_state = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=0.1, # Default. - aux_dropout_rate=None, - update_batch_norm=True) - - mask = batch['weights'] - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - per_example_losses = workload.loss_fn( - batch['targets'], logits, mask, label_smoothing=label_smoothing) - loss = torch.where(mask, per_example_losses, 0).sum() / mask.sum() - - loss.backward() - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return optimizer_state, current_param_container, new_model_state diff --git a/target_setting_runs/pytorch_adamw.py b/target_setting_runs/pytorch_adamw.py index 76582fbb2..2b15aa116 100644 --- a/target_setting_runs/pytorch_adamw.py +++ b/target_setting_runs/pytorch_adamw.py @@ -1,11 +1,15 @@ -"""Submission file for an AdamW optimizer in PyTorch.""" +"""Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" import torch -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler import LinearLR -from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec +from target_setting_runs import cosine_warmup +from target_setting_runs.data_selection import \ + data_selection # pylint: disable=unused-import +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.pytorch_submission_base import \ + update_params # pylint: disable=unused-import def init_optimizer_state(workload: spec.Workload, @@ -30,19 +34,7 @@ def init_optimizer_state(workload: spec.Workload, weight_decay=hyperparameters.l2) } - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, - 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_steps) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_steps]) + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + hyperparameters, optimizer_state['optimizer']) return optimizer_state diff --git a/target_setting_runs/pytorch_nadamw.py b/target_setting_runs/pytorch_nadamw.py index 9471c2fe7..55c42e5bd 100644 --- a/target_setting_runs/pytorch_nadamw.py +++ b/target_setting_runs/pytorch_nadamw.py @@ -5,51 +5,15 @@ import torch from torch import Tensor -from torch.optim.lr_scheduler import CosineAnnealingLR -from torch.optim.lr_scheduler 
import LinearLR -from torch.optim.lr_scheduler import SequentialLR from algorithmic_efficiency import spec - - -def init_optimizer_state(workload: spec.Workload, - model_params: spec.ParameterContainer, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - rng: spec.RandomState) -> spec.OptimizerState: - """Creates a NAdamW optimizer and a learning rate schedule.""" - del workload - del model_state - del rng - - epsilon = ( - hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) - optimizer_state = { - 'optimizer': - NAdamW( - model_params.parameters(), - lr=hyperparameters.learning_rate, - betas=(hyperparameters.beta1, hyperparameters.beta2), - eps=epsilon, - weight_decay=hyperparameters.l2) - } - - scheduler1 = LinearLR( - optimizer_state['optimizer'], - start_factor=1e-10, - end_factor=1., - total_iters=hyperparameters.warmup_steps) - cosine_steps = max(hyperparameters.num_steps - hyperparameters.warmup_steps, - 1) - scheduler2 = CosineAnnealingLR( - optimizer_state['optimizer'], T_max=cosine_steps) - - optimizer_state['scheduler'] = SequentialLR( - optimizer_state['optimizer'], - schedulers=[scheduler1, scheduler2], - milestones=[hyperparameters.warmup_steps]) - - return optimizer_state +from target_setting_runs import cosine_warmup +from target_setting_runs.data_selection import \ + data_selection # pylint: disable=unused-import +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.pytorch_submission_base import \ + update_params # pylint: disable=unused-import # Modified from github.com/pytorch/pytorch/blob/v1.12.1/torch/optim/adamw.py @@ -214,3 +178,30 @@ def nadamw(params: List[Tensor], denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) param.addcdiv_(exp_avg_hat, denom, value=-step_size) + + +def init_optimizer_state(workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState) -> spec.OptimizerState: + """Creates a NAdamW optimizer and a learning rate schedule.""" + del workload + del model_state + del rng + + epsilon = ( + hyperparameters.epsilon if hasattr(hyperparameters, 'epsilon') else 1e-8) + optimizer_state = { + 'optimizer': + NAdamW( + model_params.parameters(), + lr=hyperparameters.learning_rate, + betas=(hyperparameters.beta1, hyperparameters.beta2), + eps=epsilon, + weight_decay=hyperparameters.l2) + } + + optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( + hyperparameters, optimizer_state['optimizer']) + return optimizer_state diff --git a/target_setting_runs/pytorch_nesterov.py b/target_setting_runs/pytorch_nesterov.py index f7ade1ef5..d1b0a04c0 100644 --- a/target_setting_runs/pytorch_nesterov.py +++ b/target_setting_runs/pytorch_nesterov.py @@ -4,7 +4,13 @@ from torch.optim.lr_scheduler import LambdaLR from algorithmic_efficiency import spec +from target_setting_runs.data_selection import \ + data_selection # pylint: disable=unused-import from target_setting_runs.jax_nesterov import create_lr_schedule_fn +from target_setting_runs.get_batch_size import \ + get_batch_size # pylint: disable=unused-import +from target_setting_runs.pytorch_submission_base import \ + update_params # pylint: disable=unused-import def init_optimizer_state(workload: spec.Workload, diff --git a/target_setting_runs/fastmri/pytorch_submission.py b/target_setting_runs/pytorch_submission_base.py similarity index 86% rename from 
target_setting_runs/fastmri/pytorch_submission.py rename to target_setting_runs/pytorch_submission_base.py index dcaa125bf..483de5b60 100644 --- a/target_setting_runs/fastmri/pytorch_submission.py +++ b/target_setting_runs/pytorch_submission_base.py @@ -1,13 +1,8 @@ -"""PyTorch submission for the target-setting run on FastMRI with NAdamW.""" +"""Batch size and update submission functions in PyTorch.""" from typing import Dict, List, Tuple from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.pytorch_nadamw import \ - init_optimizer_state # pylint: disable=unused-import - def get_batch_size(workload_name): # Return the global batch size. diff --git a/target_setting_runs/wmt/__init__.py b/target_setting_runs/wmt/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/target_setting_runs/wmt/jax_submission.py b/target_setting_runs/wmt/jax_submission.py deleted file mode 100644 index dba2ae6ff..000000000 --- a/target_setting_runs/wmt/jax_submission.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Jax submission for the target-setting run on WMT with AdamW.""" - -import functools -from typing import Dict, List, Tuple - -import jax -import jax.numpy as jnp -import optax - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.jax_adamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -@functools.partial( - jax.pmap, - in_axes=(None, None, 0, 0, 0, 0, None), - axis_name='batch', - static_broadcasted_argnums=(0, 1)) -def pmapped_train_step(workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rng, - label_smoothing): - """Perform a single training step.""" - - def _loss_fn(params): - """Loss function used for training.""" - logits, _ = workload.model_fn( - params, - batch, - model_state=None, - mode=spec.ForwardPassMode.TRAIN, - rng=dropout_rng, - dropout_rate=0.1, # Default. - aux_dropout_rate=0.1, # Default. 
- update_batch_norm=False) - targets = batch['targets'] - weights = jnp.where(targets > 0, 1.0, 0.0) - loss = (workload.loss_fn(targets, logits, label_smoothing=label_smoothing) * - weights).sum() / weights.sum() - return loss - - grad_fn = jax.value_and_grad(_loss_fn) - _, grad = grad_fn(current_param_container) - grad = jax.lax.pmean(grad, axis_name='batch') - updates, new_optimizer_state = opt_update_fn( - grad, optimizer_state, current_param_container) - updated_params = optax.apply_updates(current_param_container, updates) - return new_optimizer_state, updated_params - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del eval_results - del global_step - del model_state - del loss_type - - optimizer_state, opt_update_fn = optimizer_state - dropout_rngs = jax.random.split(rng, jax.local_device_count()) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - new_optimizer_state, updated_params = pmapped_train_step( - workload, - opt_update_fn, - optimizer_state, - current_param_container, - batch, - dropout_rngs, - label_smoothing) - return (new_optimizer_state, opt_update_fn), updated_params, None diff --git a/target_setting_runs/wmt/pytorch_submission.py b/target_setting_runs/wmt/pytorch_submission.py deleted file mode 100644 index d828569f9..000000000 --- a/target_setting_runs/wmt/pytorch_submission.py +++ /dev/null @@ -1,63 +0,0 @@ -"""PyTorch submission for the target-setting run on WMT with AdamW.""" - -from typing import Dict, List, Tuple - -import torch - -from algorithmic_efficiency import spec -from target_setting_runs.data_selection import \ - data_selection # pylint: disable=unused-import -from target_setting_runs.pytorch_adamw import \ - init_optimizer_state # pylint: disable=unused-import - - -def get_batch_size(workload_name): - # Return the global batch size. - del workload_name - return 256 - - -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: - """Return (updated_optimizer_state, updated_params, updated_model_state).""" - del current_params_types - del eval_results - del loss_type - del global_step - - current_model = current_param_container - current_model.train() - optimizer_state['optimizer'].zero_grad() - - logits, _ = workload.model_fn( - params=current_model, - augmented_and_preprocessed_input_batch=batch, - model_state=model_state, - mode=spec.ForwardPassMode.TRAIN, - rng=rng, - dropout_rate=0.1, # Default. - aux_dropout_rate=0.1, # Default. 
- update_batch_norm=False) - - targets = batch['targets'] - weights = torch.where(targets > 0, 1.0, 0.0) - label_smoothing = ( - hyperparameters.label_smoothing if hasattr(hyperparameters, - 'label_smoothing') else 0.0) - loss = (workload.loss_fn(targets, logits, label_smoothing=label_smoothing) * - weights).sum() / weights.sum() - loss.backward() - - optimizer_state['optimizer'].step() - optimizer_state['scheduler'].step() - - return (optimizer_state, current_param_container, None)
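
Note on the grad-clipping path added in `jax_submission_base.py`: the scaling factor is the ratio of `grad_clip` to the global L2 norm of the gradient pytree, clamped to at most 1, so clipping only kicks in when the norm exceeds `grad_clip`. Below is a minimal, self-contained sketch of that step for reference. The helper name `clip_by_global_norm`, the `fake_grads` pytree, and the `grad_clip=1.0` value are illustrative assumptions, not part of the patch; only the `_GRAD_CLIP_EPS` constant and the clamp-to-one logic mirror the submission code.

```python
"""Standalone sketch of the global-norm gradient clipping used above."""

import jax
import jax.numpy as jnp

_GRAD_CLIP_EPS = 1e-6  # Same epsilon as in jax_submission_base.py.


def clip_by_global_norm(grad, grad_clip):
  # Global L2 norm over all leaves of the gradient pytree.
  grad_norm = jnp.sqrt(
      sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grad)))
  # Cap the scaling factor at 1 so small gradients pass through unchanged.
  scaling = jax.lax.clamp(
      min=0.0, x=grad_clip / (grad_norm + _GRAD_CLIP_EPS), max=1.0)
  return jax.tree_util.tree_map(lambda g: g * scaling, grad)


if __name__ == '__main__':
  # Made-up gradients; in the submissions these come from jax.value_and_grad.
  fake_grads = {'w': 10.0 * jnp.ones((3, 3)), 'b': 10.0 * jnp.ones(3)}
  clipped = clip_by_global_norm(fake_grads, grad_clip=1.0)
  clipped_norm = jnp.sqrt(
      sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(clipped)))
  print('clipped global norm:', clipped_norm)  # Approximately 1.0.
```

Dividing by `grad_norm + _GRAD_CLIP_EPS` rather than `grad_norm` alone avoids a division by zero when the gradient happens to be exactly zero.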