diff --git a/README.md b/README.md index ada979be7..67342ad75 100644 --- a/README.md +++ b/README.md @@ -152,6 +152,7 @@ Docker is the easiest way to enable PyTorch/JAX GPU support on Linux since only python3 submission_runner.py \ --framework=jax \ --workload=mnist \ + --experiment_dir=/home/znado \ --submission_path=reference_submissions/mnist/mnist_jax/submission.py \ --tuning_search_space=reference_submissions/mnist/tuning_search_space.json ``` @@ -162,6 +163,7 @@ python3 submission_runner.py \ python3 submission_runner.py \ --framework=pytorch \ --workload=mnist \ + --experiment_dir=/home/znado \ --submission_path=reference_submissions/mnist/mnist_pytorch/submission.py \ --tuning_search_space=reference_submissions/mnist/tuning_search_space.json ``` @@ -174,7 +176,10 @@ To do so, simply replace `python3` by torchrun --standalone --nnodes=1 --nproc_per_node=N_GPUS ``` -where `N_GPUS` is the number of available GPUs on the node. +where `N_GPUS` is the number of available GPUs on the node. To only see output from the first process, you can run the following to redirect the output from processes 1-7 to a log file: +```bash +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 + ``` ## Rules diff --git a/RULES.md b/RULES.md index b607bed90..b293e60be 100644 --- a/RULES.md +++ b/RULES.md @@ -107,7 +107,11 @@ def model_fn( ###### Loss function ```python -def loss_fn(label_batch, logits_output_batch) -> 1d array of losses per example # differentiable +def loss_fn( + label_batch: Union[Tuple[Tensor, Tensor], Tensor], + logits_batch: Union[Tuple[Tensor, Tensor], Tensor], + mask_batch: Optional[Tensor] = None, + label_smoothing: float = 0.0) -> 1d array of losses per example # differentiable ``` - Unlike in the [Model Track](#model-track), we will specify the loss function name in order to let training algorithms depend on the loss function. It will be one of {**mean squared error**, **cross-entropy**, **CTC**, or **L1 reconstruction error**}. diff --git a/algorithmic_efficiency/logger_utils.py b/algorithmic_efficiency/logger_utils.py index a35b1487c..61377807d 100644 --- a/algorithmic_efficiency/logger_utils.py +++ b/algorithmic_efficiency/logger_utils.py @@ -131,8 +131,8 @@ def _get_git_commit_hash() -> str: def _get_git_branch() -> str: - return subprocess.check_output(['git', 'branch', - '--show-current']).decode('ascii').strip() + return subprocess.check_output(['git', 'rev-parse', '--abbrev-ref', + 'HEAD']).decode('ascii').strip() def _get_cpu_model_name() -> str: diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index ed47ae3d4..d23ead15d 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -177,21 +177,26 @@ def model_fn( augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False) - return logits, None + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. 
def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable one_hot_targets = jax.nn.one_hot(label_batch, 10) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) - return -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + losses = -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses def _compute_metrics(self, logits, labels): loss = jnp.sum(self.loss_fn(labels, logits)) - # not accuracy, but nr. of correct predictions + # Number of correct predictions. accuracy = jnp.sum(jnp.argmax(logits, -1) == labels) metrics = { 'loss': loss, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py index 9a8bb5fe0..03214bc31 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_pytorch/workload.py @@ -199,12 +199,17 @@ def output_activation_fn(self, def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable - return F.cross_entropy( + losses = F.cross_entropy( logits_batch, label_batch, reduction='none', label_smoothing=label_smoothing) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses def _eval_metric(self, logits, labels): """Return the mean accuracy and loss as a dict.""" diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py index d39d6962f..a153bb6ad 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py @@ -33,17 +33,11 @@ def loss_fn( logits_batch: spec.Tensor, mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: - # TODO(znado): confirm that we do not want to tune label smoothing here.
- del label_smoothing per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy( logits=logits_batch, targets=label_batch) if mask_batch is not None: - weighted_losses = per_example_losses * mask_batch - normalization = mask_batch.sum() - else: - weighted_losses = per_example_losses - normalization = label_batch.shape[0] - return jnp.sum(weighted_losses, axis=-1) / normalization + per_example_losses *= mask_batch + return per_example_losses def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: self._model = models.DlrmSmall( diff --git a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py index 9663423b0..860bc30ed 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py @@ -41,12 +41,8 @@ def loss_fn(self, per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy( logits=logits_batch, targets=label_batch) if mask_batch is not None: - weighted_losses = per_example_losses * mask_batch - normalization = mask_batch.sum() - else: - weighted_losses = per_example_losses - normalization = label_batch.shape[0] - return torch.sum(weighted_losses, dim=-1) / normalization + per_example_losses *= mask_batch + return per_example_losses def _eval_metric(self, logits: spec.Tensor, targets: spec.Tensor) -> Dict[str, int]: @@ -81,9 +77,13 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, + dropout_rate: Optional[float], + aux_dropout_rate: Optional[float], update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del model_state del rng + del dropout_rate + del aux_dropout_rate del update_batch_norm model = params @@ -188,9 +188,11 @@ def _eval_batch(self, params, batch): model_state=None, mode=spec.ForwardPassMode.EVAL, rng=None, + dropout_rate=None, + aux_dropout_rate=None, update_batch_norm=False) per_example_losses = metrics.per_example_sigmoid_binary_cross_entropy( logits, batch['targets']) - batch_loss_numerator = torch.sum(per_example_losses) - batch_loss_denominator = torch.sum(batch['weights']) + batch_loss_numerator = torch.sum(per_example_losses).cpu().numpy() + batch_loss_denominator = torch.sum(batch['weights']).cpu().numpy() return batch_loss_numerator, batch_loss_denominator diff --git a/algorithmic_efficiency/workloads/criteo1tb/workload.py b/algorithmic_efficiency/workloads/criteo1tb/workload.py index 4524c9a8d..cea147828 100644 --- a/algorithmic_efficiency/workloads/criteo1tb/workload.py +++ b/algorithmic_efficiency/workloads/criteo1tb/workload.py @@ -28,7 +28,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.1255 + return 0.12422498 @property def loss_type(self): @@ -64,7 +64,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 10 * 60 + return 24 * 60 def output_activation_fn(self, logits_batch: spec.Tensor, diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py index 3f17e7578..25ff81bf2 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_jax/workload.py @@ -71,10 +71,19 @@ def output_activation_fn(self, # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. 
- def loss_fn(self, label_batch: spec.Tensor, - outputs_batch: spec.Tensor) -> spec.Tensor: # differentiable - return jnp.abs(outputs_batch - - label_batch).mean(axis=tuple(range(1, outputs_batch.ndim))) + def loss_fn(self, + label_batch: spec.Tensor, + logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, + label_smoothing: float = 0.0) -> spec.Tensor: # differentiable + del label_smoothing + losses = jnp.sum( + jnp.abs(logits_batch - label_batch), + axis=tuple(range(1, logits_batch.ndim))) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses @functools.partial( jax.pmap, diff --git a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py index 185a65c29..75b993937 100644 --- a/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/fastmri_pytorch/workload.py @@ -173,12 +173,15 @@ def output_activation_fn(self, def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, - mask_batch: spec.Tensor = None, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable - del mask_batch del label_smoothing - return F.l1_loss( - logits_batch, label_batch, reduction='none').mean(dim=(1, 2)) + losses = F.l1_loss( + logits_batch, label_batch, reduction='none').sum(dim=(1, 2)) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses def _eval_model(self, params, batch, rng): """Return the SSIM and loss as a dict.""" diff --git a/algorithmic_efficiency/workloads/fastmri/workload.py b/algorithmic_efficiency/workloads/fastmri/workload.py index 2829fb76a..f414b9451 100644 --- a/algorithmic_efficiency/workloads/fastmri/workload.py +++ b/algorithmic_efficiency/workloads/fastmri/workload.py @@ -18,7 +18,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.70 + return 0.735102235 @property def loss_type(self): @@ -62,7 +62,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 6000 # 100 mins + return 80 @property def param_shapes(self): diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 77803806e..41e003e89 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -4,7 +4,7 @@ https://github.com/google/flax/blob/main/examples/imagenet/models.py """ -from functools import partial +import functools from typing import Any, Callable, Tuple from flax import linen as nn @@ -79,8 +79,8 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x, update_batch_norm: bool = True): - conv = partial(nn.Conv, use_bias=False, dtype=self.dtype) - norm = partial( + conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + norm = functools.partial( nn.BatchNorm, use_running_average=not update_batch_norm, momentum=0.9, @@ -112,16 +112,18 @@ def __call__(self, x, update_batch_norm: bool = True): return x -ResNet18 = partial(ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock) -ResNet34 = partial(ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock) -ResNet50 = partial( +ResNet18 = functools.partial( + ResNet, stage_sizes=(2, 2, 2, 2), block_cls=ResNetBlock)
+ResNet34 = functools.partial( + ResNet, stage_sizes=(3, 4, 6, 3), block_cls=ResNetBlock) +ResNet50 = functools.partial( ResNet, stage_sizes=(3, 4, 6, 3), block_cls=BottleneckResNetBlock) -ResNet101 = partial( +ResNet101 = functools.partial( ResNet, stage_sizes=(3, 4, 23, 3), block_cls=BottleneckResNetBlock) -ResNet152 = partial( +ResNet152 = functools.partial( ResNet, stage_sizes=(3, 8, 36, 3), block_cls=BottleneckResNetBlock) -ResNet200 = partial( +ResNet200 = functools.partial( ResNet, stage_sizes=(3, 24, 36, 3), block_cls=BottleneckResNetBlock) # Used for testing only. -_ResNet1 = partial(ResNet, stage_sizes=(1,), block_cls=ResNetBlock) +_ResNet1 = functools.partial(ResNet, stage_sizes=(1,), block_cls=ResNetBlock) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index a04537318..27cd074df 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -86,18 +86,14 @@ def model_params_types(self): self._param_shapes.unfreeze()) return self._param_types - def initialized(self, key, model): - input_shape = (1, 224, 224, 3) - variables = jax.jit(model.init)({'params': key}, - jnp.ones(input_shape, model.dtype)) - model_state, params = variables.pop('params') - return params, model_state - def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState: model_cls = getattr(models, 'ResNet50') model = model_cls(num_classes=self._num_classes, dtype=jnp.float32) self._model = model - params, model_state = self.initialized(rng, model) + input_shape = (1, 224, 224, 3) + variables = jax.jit(model.init)({'params': rng}, + jnp.ones(input_shape, model.dtype)) + model_state, params = variables.pop('params') self._param_shapes = jax.tree_map(lambda x: spec.ShapeTuple(x.shape), params) model_state = jax_utils.replicate(model_state) @@ -158,13 +154,14 @@ def model_fn( augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, mutable=False) - return logits, None + return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in # `update_params`. def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable """Cross Entropy Loss""" if label_batch.shape[-1] != self._num_classes: @@ -173,8 +170,11 @@ def loss_fn(self, else: one_hot_labels = label_batch smoothed_labels = optax.smooth_labels(one_hot_labels, label_smoothing) - return optax.softmax_cross_entropy( - logits=logits_batch, labels=smoothed_labels) + losses = -jnp.sum(smoothed_labels * jax.nn.log_softmax(logits_batch, axis=-1), axis=-1) + # mask_batch is assumed to be shape [batch].
+ if mask_batch is not None: + losses *= mask_batch + return losses def _compute_metrics(self, logits, labels): loss = jnp.sum(self.loss_fn(labels, logits)) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py index ae988bbce..4e518ef7e 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_pytorch/workload.py @@ -213,12 +213,17 @@ def output_activation_fn(self, def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable - return F.cross_entropy( + losses = F.cross_entropy( logits_batch, label_batch, reduction='none', label_smoothing=label_smoothing) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses def _eval_metric(self, logits, labels): """Return the mean accuracy and loss as a dict.""" diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/workload.py index ada9c9dab..8199331df 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/workload.py @@ -19,7 +19,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.76 + return 0.771850005 @property def loss_type(self): @@ -73,7 +73,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 6000 # 100 mins + return 510 # 8.5 minutes. @property def param_shapes(self): diff --git a/algorithmic_efficiency/workloads/imagenet_vit/workload.py b/algorithmic_efficiency/workloads/imagenet_vit/workload.py index bdc4d8496..09752129d 100644 --- a/algorithmic_efficiency/workloads/imagenet_vit/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_vit/workload.py @@ -52,7 +52,7 @@ class BaseImagenetVitWorkload(BaseImagenetResNetWorkload): @property def target_value(self): - return 0.76 + return 0.77171 @property def max_allowed_runtime_sec(self): @@ -60,7 +60,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 6000 # 100 mins + return 7 * 60 # 7 mins.
def _build_dataset(self, data_rng: spec.RandomState, diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 3edbcea71..42440b8bc 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -400,7 +400,8 @@ def __call__(self, inputs, paddings, train): else: attention_residual_dropout_rate = config.attention_residual_dropout_rate result = nn.Dropout( - rate=attention_residual_dropout_rate, deterministic=not train)(result) + rate=attention_residual_dropout_rate, deterministic=not train)( + result) return result @@ -540,7 +541,8 @@ def __call__(self, inputs, input_paddings, train): else: conv_residual_dropout_rate = config.conv_residual_dropout_rate inputs = nn.Dropout( - rate=conv_residual_dropout_rate, deterministic=not train)(inputs) + rate=conv_residual_dropout_rate, deterministic=not train)( + inputs) return inputs diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index b7e1a2d05..5c7c7e5df 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -105,7 +105,7 @@ def _model_fn( input_paddings, train=False, mutable=False) - return (logits, logit_paddings), None + return (logits, logit_paddings), model_state @property def model_params_types(self): diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py index 4242bd16d..73f29e7a5 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/workload.py @@ -23,8 +23,9 @@ MAX_INPUT_LENGTH = 320000 -def _maybe_update_model_dropout( - model, residual_dropout_rate, input_dropout_rate): +def _maybe_update_model_dropout(model, + residual_dropout_rate, + input_dropout_rate): for child in list(model.modules()): # Residual dropout. 
if (isinstance(child, conformer_model.MultiHeadedSelfAttention) and diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py index 97b7d15af..3a4082e0e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/workload.py @@ -25,7 +25,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 0.109 + return 0.08420191 @property def loss_type(self): @@ -61,7 +61,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 2500 + return 11 * 60 @property def param_shapes(self): diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py index 2cb583f89..d1bf11ec5 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/models.py @@ -106,7 +106,8 @@ def __call__(self, inputs, output_paddings, train): else: input_dropout_rate = config.input_dropout_rate outputs = nn.Dropout( - rate=input_dropout_rate, deterministic=not train)(outputs) + rate=input_dropout_rate, deterministic=not train)( + outputs) return outputs, output_paddings diff --git a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py index 271dbd7a1..13f958685 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py @@ -157,10 +157,15 @@ def model_fn( def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable one_hot_targets = jax.nn.one_hot(label_batch, 10) smoothed_targets = optax.smooth_labels(one_hot_targets, label_smoothing) - return -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + losses = -jnp.sum(smoothed_targets * nn.log_softmax(logits_batch), axis=-1) + # mask_batch is assumed to be shape [batch]. + if mask_batch is not None: + losses *= mask_batch + return losses @functools.partial( jax.pmap, diff --git a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py index 0ff56a803..c9fe83485 100644 --- a/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py +++ b/algorithmic_efficiency/workloads/mnist/mnist_pytorch/workload.py @@ -177,12 +177,17 @@ def output_activation_fn(self, def loss_fn(self, label_batch: spec.Tensor, logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: # differentiable - return F.cross_entropy( + losses = F.cross_entropy( logits_batch, label_batch, reduction='none', label_smoothing=label_smoothing) + # mask_batch is assumed to be shape [batch].
+ if mask_batch is not None: + losses *= mask_batch + return losses def _eval_model( self, diff --git a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py index 940040c32..0d8b7d4ba 100644 --- a/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py +++ b/algorithmic_efficiency/workloads/ogbg/ogbg_jax/models.py @@ -1,6 +1,6 @@ # Forked from the init2winit implementation here # https://github.com/google/init2winit/blob/master/init2winit/model_lib/gnn.py. -from typing import Tuple +from typing import Optional, Tuple from flax import linen as nn import jax.numpy as jnp @@ -39,12 +39,17 @@ class GNN(nn.Module): num_outputs: int latent_dim: int = 256 hidden_dims: Tuple[int] = (256,) - dropout_rate: float = 0.1 + # If None, defaults to 0.1. + dropout_rate: Optional[float] = 0.1 num_message_passing_steps: int = 5 @nn.compact def __call__(self, graph, train): - dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train) + if self.dropout_rate is None: + dropout_rate = 0.1 + else: + dropout_rate = self.dropout_rate + dropout = nn.Dropout(rate=dropout_rate, deterministic=not train) graph = graph._replace( globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) diff --git a/algorithmic_efficiency/workloads/ogbg/workload.py b/algorithmic_efficiency/workloads/ogbg/workload.py index 80f2438a9..5723ba870 100644 --- a/algorithmic_efficiency/workloads/ogbg/workload.py +++ b/algorithmic_efficiency/workloads/ogbg/workload.py @@ -25,9 +25,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - # From Flax example - # https://tensorboard.dev/experiment/AAJqfvgSRJaA1MBkc0jMWQ/#scalars. - return 0.24 + return 0.28380056 @property def loss_type(self): @@ -63,7 +61,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 120 + return 4 * 60 @property def param_shapes(self): diff --git a/algorithmic_efficiency/workloads/wmt/workload.py b/algorithmic_efficiency/workloads/wmt/workload.py index 58d940d3f..b04d1808a 100644 --- a/algorithmic_efficiency/workloads/wmt/workload.py +++ b/algorithmic_efficiency/workloads/wmt/workload.py @@ -33,7 +33,7 @@ def has_reached_goal(self, eval_result: float) -> bool: @property def target_value(self): - return 25 + return 30.8788074 @property def loss_type(self): @@ -72,7 +72,7 @@ def max_allowed_runtime_sec(self): @property def eval_period_time_sec(self): - return 2400 + return 14 * 60 def _build_input_queue(self, data_rng: jax.random.PRNGKey, @@ -211,6 +211,10 @@ def loss_fn( self, label_batch: spec.Tensor, # Dense (not one-hot) labels. logits_batch: spec.Tensor, + mask_batch: Optional[spec.Tensor] = None, label_smoothing: float = 0.0) -> spec.Tensor: return self.compute_weighted_cross_entropy( - logits_batch, label_batch, label_smoothing=label_smoothing) + logits_batch, + label_batch, + weights=mask_batch, + label_smoothing=label_smoothing) diff --git a/submission_runner.py b/submission_runner.py index f0a36a07b..08aad2b6e 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -83,12 +83,12 @@ flags.DEFINE_string( 'submission_path', - 'reference_submissions/mnist/mnist_jax/submission.py', + None, 'The relative path of the Python file containing submission functions. 
' 'NOTE: the submission dir must have an __init__.py file!') flags.DEFINE_string( 'workload', - 'mnist', + None, help=f'The name of the workload to run.\n Choices: {list(WORKLOADS.keys())}' ) flags.DEFINE_enum( @@ -98,10 +98,10 @@ help='Which tuning ruleset to use.') flags.DEFINE_string( 'tuning_search_space', - 'reference_submissions/mnist/tuning_search_space.json', + None, 'The path to the JSON file describing the external tuning search space.') flags.DEFINE_integer('num_tuning_trials', - 20, + 1, 'The number of external hyperparameter trials to run.') flags.DEFINE_string('data_dir', '~/tensorflow_datasets/', 'Dataset location.') flags.DEFINE_string('imagenet_v2_data_dir', @@ -123,7 +123,7 @@ 'The root directory to store all experiments. ' 'It is not required, but the directory should have ' 'an absolute path rather than a relative path.') -flags.DEFINE_string('experiment_name', '', 'Name of the experiment.') +flags.DEFINE_string('experiment_name', None, 'Name of the experiment.') flags.DEFINE_boolean('profile', False, 'Whether to produce profiling output.') FLAGS = flags.FLAGS @@ -206,13 +206,13 @@ def train_once( flag_filename = os.path.join(log_dir, 'flags.json') if RANK == 0: - logging.info('saving hparams to %s', hparams_filename) + logging.info('Saving hparams to %s', hparams_filename) with open(hparams_filename, 'w') as f: f.write(json.dumps(hyperparameters._asdict(), indent=2)) - logging.info('saving meta data to %s', meta_filename) + logging.info('Saving meta data to %s', meta_filename) with open(meta_filename, 'w') as f: f.write(json.dumps(meta_data, indent=2)) - logging.info('saving flags to %s', flag_filename) + logging.info('Saving flags to %s', flag_filename) with open(flag_filename, 'w') as f: f.write(json.dumps(flags.FLAGS.flag_values_dict(), indent=2)) metrics_logger = set_up_loggers(log_dir, flags.FLAGS) @@ -436,7 +436,7 @@ def main(_): pytorch_init(USE_PYTORCH_DDP, RANK, profiler) workload_metadata = WORKLOADS[FLAGS.workload] - # extend path according to framework + # Extend path according to framework. workload_metadata['workload_path'] = os.path.join( BASE_WORKLOADS_DIR, workload_metadata['workload_path'] + '_' + FLAGS.framework, @@ -445,12 +445,15 @@ def main(_): workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name']) - experiment_name = FLAGS.workload + '_' + FLAGS.framework - if FLAGS.experiment_name != '': - experiment_name = experiment_name + '_' + FLAGS.experiment_nname - experiment_log_dir = os.path.join(FLAGS.experiment_dir, experiment_name) + workload_dir_name = FLAGS.workload + '_' + FLAGS.framework + if FLAGS.experiment_name is None: + experiment_log_dir = os.path.join(FLAGS.experiment_dir, workload_dir_name) + else: + experiment_log_dir = os.path.join(FLAGS.experiment_dir, + FLAGS.experiment_name, + workload_dir_name) if RANK == 0: - # only one worker should create the required dir + # Only one worker should create the required dir. logging.info('Creating experiment directory at %s', experiment_log_dir) os.makedirs(name=experiment_log_dir, exist_ok=True) @@ -471,7 +474,7 @@ def main(_): logging.info(profiler.summary()) if USE_PYTORCH_DDP: - # cleanup + # Cleanup. 
dist.destroy_process_group() @@ -479,7 +482,6 @@ def main(_): flags.mark_flag_as_required('workload') flags.mark_flag_as_required('framework') flags.mark_flag_as_required('submission_path') - flags.mark_flag_as_required('tuning_ruleset') flags.mark_flag_as_required('tuning_search_space') flags.mark_flag_as_required('experiment_dir') app.run(main) diff --git a/target_setting_runs/README.md b/target_setting_runs/README.md index 542d42b1d..5000a13c7 100644 --- a/target_setting_runs/README.md +++ b/target_setting_runs/README.md @@ -1,137 +1,191 @@ # Target Setting Run replications Original runs were run on Google TPUv2-8 machines. +These are not valid submissions, because they use a different hyperparameter setting per workload. But we include them in order to reproduce how we set the target metric values. + ## Criteo Target was set using AdamW with a linear warmup cosine decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado/algorithmic-efficiency \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=criteo1tb \ --submission_path=target_setting_runs/jax_adamw.py \ --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado/algorithmic-efficiency \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=criteo1tb \ --submission_path=target_setting_runs/pytorch_adamw.py \ --tuning_search_space=target_setting_runs/criteo1tb/tuning_search_space.json ``` -# FastMRI +## FastMRI Target was set using NAdamW with a linear warmup cosine decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=fastmri \ --submission_path=target_setting_runs/jax_nadamw.py \ --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=fastmri \ --submission_path=target_setting_runs/pytorch_nadamw.py \ --tuning_search_space=target_setting_runs/fastmri/tuning_search_space.json ``` -# ImageNet-Resnet +## ImageNet-Resnet Target was set using Nesterov with a linear warmup and linear decay LR schedule. 
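The linear warmup plus linear decay schedule referenced above is built by `create_lr_schedule_fn` in `target_setting_runs/jax_nesterov.py`. As a rough, self-contained illustration (not the repo's exact code), an equivalent schedule can be sketched with `optax`; the function name and the example values below are hypothetical, while `warmup_steps`, `num_steps`, `decay_steps_factor`, and `end_factor` mirror the tuning search space:
```python
import optax


def linear_warmup_linear_decay(base_lr, warmup_steps, num_steps,
                               decay_steps_factor, end_factor):
  # Linear ramp from 0 to base_lr over warmup_steps.
  warmup = optax.linear_schedule(
      init_value=0.0, end_value=base_lr, transition_steps=warmup_steps)
  # Linear decay from base_lr down to base_lr * end_factor.
  decay_steps = int((num_steps - warmup_steps) * decay_steps_factor)
  decay = optax.linear_schedule(
      init_value=base_lr,
      end_value=base_lr * end_factor,
      transition_steps=decay_steps)
  return optax.join_schedules([warmup, decay], boundaries=[warmup_steps])


# Example values only; the real ones come from the tuning search space JSON.
schedule_fn = linear_warmup_linear_decay(
    base_lr=0.1, warmup_steps=5000, num_steps=100000,
    decay_steps_factor=0.90788, end_factor=0.01)
```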
```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado \ + --imagenet_v2_data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=imagenet_resnet \ --submission_path=target_setting_runs/jax_nesterov.py \ --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado/imagenet_pytorch \ + --imagenet_v2_data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=imagenet_resnet \ --submission_path=target_setting_runs/pytorch_nesterov.py \ --tuning_search_space=target_setting_runs/imagenet_resnet/tuning_search_space.json ``` -# ImageNet-ViT +## ImageNet-ViT Target was set using NAdamW with a linear warmup cosine decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado \ + --imagenet_v2_data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=imagenet_vit \ --submission_path=target_setting_runs/jax_nadamw.py \ --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado/imagenet_pytorch \ + --imagenet_v2_data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=imagenet_vit \ --submission_path=target_setting_runs/pytorch_nadamw.py \ --tuning_search_space=target_setting_runs/imagenet_vit/tuning_search_space.json ``` -# Librispeech-Conformer +## Librispeech-Conformer Target was set using AdamW with a linear warmup cosine decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=librispeech_conformer \ --submission_path=target_setting_runs/jax_adamw.py \ --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=librispeech_conformer \ --submission_path=target_setting_runs/pytorch_adamw.py \ --tuning_search_space=target_setting_runs/librispeech_conformer/tuning_search_space.json ``` -# Librispeech-Deepspeech +## Librispeech-Deepspeech Target was set using NAdamW with a linear warmup cosine decay LR schedule. 
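The linear warmup plus cosine decay schedule used by the AdamW/NAdamW runs is constructed in `target_setting_runs/cosine_warmup.py` (see `pytorch_cosine_warmup` later in this diff). A minimal `optax` sketch of the same shape of schedule, with made-up step counts rather than the tuned values:
```python
import optax

# Illustrative numbers only; the target-setting runs read these from the
# tuning search space (learning_rate, warmup_steps, num_steps).
schedule_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.001,     # hyperparameters.learning_rate
    warmup_steps=2000,    # hyperparameters.warmup_steps
    decay_steps=60000,    # hyperparameters.num_steps (includes warmup)
    end_value=0.0)
```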
```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=librispeech_deepspeech \ --submission_path=target_setting_runs/jax_nadamw.py \ --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=librispeech_deepspeech \ --submission_path=target_setting_runs/pytorch_nadamw.py \ --tuning_search_space=target_setting_runs/librispeech_deepspeech/tuning_search_space.json ``` -# OGBG +## OGBG Target was set using Nesterov with a linear warmup and linear decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado/tensorflow_datasets \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=ogbg \ --submission_path=target_setting_runs/jax_nesterov.py \ --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado/tensorflow_datasets \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=ogbg \ --submission_path=target_setting_runs/pytorch_nesterov.py \ --tuning_search_space=target_setting_runs/ogbg/tuning_search_space.json ``` -# WMT +## WMT Target was set using AdamW with a linear warmup cosine decay LR schedule. ```bash python3 submission_runner.py \ --framework=jax \ + --data_dir=/home/znado/tensorflow_datasets \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=wmt \ --submission_path=target_setting_runs/jax_adamw.py \ --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json ``` ```bash -python3 submission_runner.py \ +torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc_per_node=8 submission_runner.py \ --framework=pytorch \ + --data_dir=/home/znado/tensorflow_datasets \ + --experiment_dir=/home/znado \ + --experiment_name=target_setting \ --workload=wmt \ --submission_path=target_setting_runs/pytorch_adamw.py \ --tuning_search_space=target_setting_runs/wmt/tuning_search_space.json diff --git a/target_setting_runs/fastmri/tuning_search_space.json b/target_setting_runs/fastmri/tuning_search_space.json index 1bfb721a5..5c52895a3 100644 --- a/target_setting_runs/fastmri/tuning_search_space.json +++ b/target_setting_runs/fastmri/tuning_search_space.json @@ -1,32 +1,32 @@ { "learning_rate": { "feasible_points": [ - 0.000487 + 0.000808 ] }, "beta1": { "feasible_points": [ - 0.8194 + 0.99578 ] }, "beta2": { "feasible_points": [ - 0.9803 + 0.997681 ] }, "warmup_steps": { "feasible_points": [ - 2171 + 2714 ] }, "num_steps": { "feasible_points": [ - 108568 + 27142 ] }, - "l2": { + "weight_decay": { "feasible_points": [ - 0.407336 + 0.95778 ] } } \ No newline at end of file diff --git a/target_setting_runs/get_batch_size.py b/target_setting_runs/get_batch_size.py index 7b3bc0408..6dcd5396d 100644 --- a/target_setting_runs/get_batch_size.py +++ b/target_setting_runs/get_batch_size.py @@ -3,7 +3,7 @@ def get_batch_size(workload_name): # Return the global batch size. 
- if workload_name == 'criteo1tb_dlrm': + if workload_name == 'criteo1tb': return 524288 elif workload_name == 'fastmri': return 32 diff --git a/target_setting_runs/imagenet_resnet/tuning_search_space.json b/target_setting_runs/imagenet_resnet/tuning_search_space.json index bf2ea6c1b..80d938b03 100644 --- a/target_setting_runs/imagenet_resnet/tuning_search_space.json +++ b/target_setting_runs/imagenet_resnet/tuning_search_space.json @@ -6,7 +6,7 @@ }, "beta1": { "feasible_points": [ - 0.99 + 0.990086 ] }, "warmup_steps": { @@ -21,7 +21,7 @@ }, "decay_steps_factor": { "feasible_points": [ - 0.9079 + 0.90788 ] }, "end_factor": { @@ -29,9 +29,9 @@ 0.01 ] }, - "l2": { + "weight_decay": { "feasible_points": [ - 7.6e-6 + 7.61e-6 ] }, "label_smoothing": { diff --git a/target_setting_runs/imagenet_vit/tuning_search_space.json b/target_setting_runs/imagenet_vit/tuning_search_space.json index 2dc6cc3fe..14b1cd4a8 100644 --- a/target_setting_runs/imagenet_vit/tuning_search_space.json +++ b/target_setting_runs/imagenet_vit/tuning_search_space.json @@ -21,10 +21,10 @@ }, "num_steps": { "feasible_points": [ - 140752 + 140000 ] }, - "l2": { + "weight_decay": { "feasible_points": [ 0.026595 ] diff --git a/target_setting_runs/jax_adamw.py b/target_setting_runs/jax_adamw.py index a91a513f1..23c05d714 100644 --- a/target_setting_runs/jax_adamw.py +++ b/target_setting_runs/jax_adamw.py @@ -36,7 +36,7 @@ def init_optimizer_state(workload: spec.Workload, b1=hyperparameters.beta1, b2=hyperparameters.beta2, eps=epsilon, - weight_decay=hyperparameters.l2) + weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/target_setting_runs/jax_nadamw.py b/target_setting_runs/jax_nadamw.py index 2dc57c514..bc3deaab5 100644 --- a/target_setting_runs/jax_nadamw.py +++ b/target_setting_runs/jax_nadamw.py @@ -8,8 +8,6 @@ import jax.numpy as jnp import optax -from algorithmic_efficiency import spec - from algorithmic_efficiency import spec from target_setting_runs import cosine_warmup from target_setting_runs.data_selection import \ @@ -165,7 +163,7 @@ def init_optimizer_state(workload: spec.Workload, b1=hyperparameters.beta1, b2=hyperparameters.beta2, eps=epsilon, - weight_decay=hyperparameters.l2) + weight_decay=hyperparameters.weight_decay) optimizer_state = opt_init_fn(params_zeros_like) return jax_utils.replicate(optimizer_state), opt_update_fn diff --git a/target_setting_runs/jax_nesterov.py b/target_setting_runs/jax_nesterov.py index 92e695b9e..3d3e8b7e2 100644 --- a/target_setting_runs/jax_nesterov.py +++ b/target_setting_runs/jax_nesterov.py @@ -34,7 +34,7 @@ def init_optimizer_state(workload: spec.Workload, workload.param_shapes) opt_init_fn, opt_update_fn = sgd( learning_rate=lr_schedule_fn, - weight_decay=hyperparameters.l2, + weight_decay=hyperparameters.weight_decay, momentum=hyperparameters.beta1, nesterov=True) optimizer_state = opt_init_fn(params_zeros_like) @@ -67,7 +67,7 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): NOTE: We apply weight decay **before** computing the momentum update. This is equivalent to applying WD after for heavy-ball momentum, - but slightly different when using Nesterov accelleration. This is the same as + but slightly different when using Nesterov acceleration. This is the same as how the Flax optimizers handle weight decay https://flax.readthedocs.io/en/latest/_modules/flax/optim/momentum.html. 
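To make the ordering described in that docstring concrete, here is a rough single-parameter sketch (not the repo's actual optax transform) of applying weight decay before the Nesterov momentum update; all names and values are illustrative:
```python
import jax.numpy as jnp


def nesterov_step(param, grad, velocity, lr, weight_decay, momentum):
  grad = grad + weight_decay * param     # Weight decay added *before* momentum.
  velocity = momentum * velocity + grad  # Momentum buffer update.
  update = momentum * velocity + grad    # Nesterov look-ahead step.
  return param - lr * update, velocity


param, velocity = jnp.array(1.0), jnp.array(0.0)
param, velocity = nesterov_step(
    param, grad=jnp.array(0.5), velocity=velocity,
    lr=0.1, weight_decay=1e-4, momentum=0.9)
```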
diff --git a/target_setting_runs/jax_submission_base.py b/target_setting_runs/jax_submission_base.py index 19a0a19fc..0a06b56bd 100644 --- a/target_setting_runs/jax_submission_base.py +++ b/target_setting_runs/jax_submission_base.py @@ -1,6 +1,5 @@ """Update submission function in Jax.""" import functools -from multiprocessing.sharedctypes import Value from typing import Dict, List, Tuple import jax @@ -10,7 +9,6 @@ from algorithmic_efficiency import spec - _GRAD_CLIP_EPS = 1e-6 @@ -40,18 +38,22 @@ def _loss_fn(params): # There was no dropout rate tuning in the target setting runs. dropout_rate=None, aux_dropout_rate=None, - update_batch_norm=False) + update_batch_norm=True) loss = jnp.mean( workload.loss_fn( - batch['targets'], logits, label_smoothing=label_smoothing)) - return loss, (new_model_state, logits) + label_batch=batch['targets'], + logits_batch=logits, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing)) + return loss, new_model_state grad_fn = jax.value_and_grad(_loss_fn, has_aux=True) - (new_model_state, _), grad = grad_fn(current_param_container) + (loss, new_model_state), grad = grad_fn(current_param_container) + del loss grad = lax.pmean(grad, axis_name='batch') if grad_clip is not None: - grad_norm = sum(jnp.sum(g ** 2) for g in jax.tree_leaves(grad)) + grad_norm = sum(jnp.sum(g**2) for g in jax.tree_leaves(grad)) grad_scaling_factor = grad_clip / (grad_norm + _GRAD_CLIP_EPS) grad_scaling_factor = jax.lax.clamp(min=0.0, x=grad_scaling_factor, max=1.0) grad = jax.tree_map(lambda x: x * grad_scaling_factor, grad) @@ -59,7 +61,7 @@ def _loss_fn(params): updates, new_optimizer_state = opt_update_fn(grad, optimizer_state, current_param_container) updated_params = optax.apply_updates(current_param_container, updates) - return new_model_state, new_optimizer_state, updated_params + return new_optimizer_state, updated_params, new_model_state def update_params(workload: spec.Workload, @@ -89,7 +91,7 @@ def update_params(workload: spec.Workload, grad_clip = hyperparameters.grad_clip else: grad_clip = None - new_model_state, new_optimizer_state, new_params = pmapped_train_step( + new_optimizer_state, new_params, new_model_state = pmapped_train_step( workload, opt_update_fn, model_state, optimizer_state, current_param_container, batch, per_device_rngs, grad_clip, label_smoothing) diff --git a/target_setting_runs/librispeech_conformer/tuning_search_space.json b/target_setting_runs/librispeech_conformer/tuning_search_space.json index 87ad14309..d2bfae816 100644 --- a/target_setting_runs/librispeech_conformer/tuning_search_space.json +++ b/target_setting_runs/librispeech_conformer/tuning_search_space.json @@ -16,15 +16,15 @@ }, "warmup_steps": { "feasible_points": [ - 11500 + 10000 ] }, "num_steps": { "feasible_points": [ - 115000 + 100000 ] }, - "l2": { + "weight_decay": { "feasible_points": [ 0.026595 ] diff --git a/target_setting_runs/librispeech_deepspeech/tuning_search_space.json b/target_setting_runs/librispeech_deepspeech/tuning_search_space.json index 3765ef098..07c2626d3 100644 --- a/target_setting_runs/librispeech_deepspeech/tuning_search_space.json +++ b/target_setting_runs/librispeech_deepspeech/tuning_search_space.json @@ -24,7 +24,7 @@ 60000 ] }, - "l2": { + "weight_decay": { "feasible_points": [ 0.107175 ] diff --git a/target_setting_runs/ogbg/tuning_search_space.json b/target_setting_runs/ogbg/tuning_search_space.json index 3765ef098..ef7fb9a44 100644 --- a/target_setting_runs/ogbg/tuning_search_space.json +++ 
b/target_setting_runs/ogbg/tuning_search_space.json @@ -1,22 +1,17 @@ { "learning_rate": { "feasible_points": [ - 0.002632 + 1.60808 ] }, "beta1": { "feasible_points": [ - 0.9945 - ] - }, - "beta2": { - "feasible_points": [ - 0.9963 + 0.9536 ] }, "warmup_steps": { "feasible_points": [ - 1200 + 3000 ] }, "num_steps": { @@ -24,9 +19,19 @@ 60000 ] }, - "l2": { + "decay_steps_factor": { + "feasible_points": [ + 0.838111 + ] + }, + "end_factor": { + "feasible_points": [ + 0.01 + ] + }, + "weight_decay": { "feasible_points": [ - 0.107175 + 3.1e-7 ] } } \ No newline at end of file diff --git a/target_setting_runs/pytorch_adamw.py b/target_setting_runs/pytorch_adamw.py index 2b15aa116..146c3aba9 100644 --- a/target_setting_runs/pytorch_adamw.py +++ b/target_setting_runs/pytorch_adamw.py @@ -31,7 +31,7 @@ def init_optimizer_state(workload: spec.Workload, lr=hyperparameters.learning_rate, betas=(hyperparameters.beta1, hyperparameters.beta2), eps=epsilon, - weight_decay=hyperparameters.l2) + weight_decay=hyperparameters.weight_decay) } optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( diff --git a/target_setting_runs/pytorch_nadamw.py b/target_setting_runs/pytorch_nadamw.py index 55c42e5bd..4d2561723 100644 --- a/target_setting_runs/pytorch_nadamw.py +++ b/target_setting_runs/pytorch_nadamw.py @@ -199,7 +199,7 @@ def init_optimizer_state(workload: spec.Workload, lr=hyperparameters.learning_rate, betas=(hyperparameters.beta1, hyperparameters.beta2), eps=epsilon, - weight_decay=hyperparameters.l2) + weight_decay=hyperparameters.weight_decay) } optimizer_state['scheduler'] = cosine_warmup.pytorch_cosine_warmup( diff --git a/target_setting_runs/pytorch_nesterov.py b/target_setting_runs/pytorch_nesterov.py index d1b0a04c0..b3bdf4817 100644 --- a/target_setting_runs/pytorch_nesterov.py +++ b/target_setting_runs/pytorch_nesterov.py @@ -6,9 +6,9 @@ from algorithmic_efficiency import spec from target_setting_runs.data_selection import \ data_selection # pylint: disable=unused-import -from target_setting_runs.jax_nesterov import create_lr_schedule_fn from target_setting_runs.get_batch_size import \ get_batch_size # pylint: disable=unused-import +from target_setting_runs.jax_nesterov import create_lr_schedule_fn from target_setting_runs.pytorch_submission_base import \ update_params # pylint: disable=unused-import @@ -30,7 +30,7 @@ def init_optimizer_state(workload: spec.Workload, model_params.parameters(), lr=hyperparameters.learning_rate, momentum=hyperparameters.beta1, - weight_decay=hyperparameters.l2, + weight_decay=hyperparameters.weight_decay, nesterov=True) } diff --git a/target_setting_runs/pytorch_submission_base.py b/target_setting_runs/pytorch_submission_base.py index 483de5b60..daaf8abc5 100644 --- a/target_setting_runs/pytorch_submission_base.py +++ b/target_setting_runs/pytorch_submission_base.py @@ -4,6 +4,7 @@ from algorithmic_efficiency import spec + def get_batch_size(workload_name): # Return the global batch size. del workload_name @@ -31,13 +32,14 @@ def update_params(workload: spec.Workload, current_model.train() optimizer_state['optimizer'].zero_grad() - outputs_batch, new_model_state = workload.model_fn( + logits_batch, new_model_state = workload.model_fn( params=current_model, augmented_and_preprocessed_input_batch=batch, model_state=model_state, mode=spec.ForwardPassMode.TRAIN, rng=rng, - dropout_rate=0.0, # Default. + # There was no dropout rate tuning in the target setting runs. 
+ dropout_rate=None, aux_dropout_rate=None, update_batch_norm=True) @@ -46,7 +48,8 @@ def update_params(workload: spec.Workload, 'label_smoothing') else 0.0) loss = workload.loss_fn( label_batch=batch['targets'], - outputs_batch=outputs_batch, + logits_batch=logits_batch, + mask_batch=batch.get('weights'), label_smoothing=label_smoothing).mean() loss.backward() diff --git a/target_setting_runs/wmt/tuning_search_space.json b/target_setting_runs/wmt/tuning_search_space.json index 3d5a42cdc..0502704e9 100644 --- a/target_setting_runs/wmt/tuning_search_space.json +++ b/target_setting_runs/wmt/tuning_search_space.json @@ -1,37 +1,37 @@ { "learning_rate": { "feasible_points": [ - 0.000844 + 0.000921 ] }, "beta1": { "feasible_points": [ - 0.8895 + 0.850323 ] }, "beta2": { "feasible_points": [ - 0.9978 + 0.996825 ] }, "warmup_steps": { "feasible_points": [ - 1000 + 2000 ] }, "num_steps": { "feasible_points": [ - 50000 + 100000 ] }, - "l2": { + "weight_decay": { "feasible_points": [ - 0.081354 + 0.026842 ] }, "label_smoothing": { "feasible_points": [ - 0.1 + 0.140451 ] } } \ No newline at end of file
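The common thread in the workload changes above is that every `loss_fn` now returns per-example losses and accepts an optional `mask_batch` (plus `label_smoothing`), leaving aggregation to the caller: the target-setting submissions take `jnp.mean` of the returned vector (see `jax_submission_base.py`), while the Criteo eval path divides by the sum of the weights. A minimal self-contained sketch of the pattern, with hypothetical names rather than code from this diff:
```python
from typing import Optional

import jax
import jax.numpy as jnp


def per_example_cross_entropy(label_batch: jnp.ndarray,
                              logits_batch: jnp.ndarray,
                              mask_batch: Optional[jnp.ndarray] = None,
                              label_smoothing: float = 0.0) -> jnp.ndarray:
  """Returns a [batch]-shaped loss vector; masked examples contribute 0."""
  num_classes = logits_batch.shape[-1]
  one_hot = jax.nn.one_hot(label_batch, num_classes)
  smoothed = one_hot * (1.0 - label_smoothing) + label_smoothing / num_classes
  losses = -jnp.sum(smoothed * jax.nn.log_softmax(logits_batch), axis=-1)
  if mask_batch is not None:  # mask_batch is assumed to be shape [batch].
    losses *= mask_batch
  return losses


logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
labels = jnp.array([0, 2])
weights = jnp.array([1.0, 0.0])  # Second example is padding.
losses = per_example_cross_entropy(labels, logits, mask_batch=weights)
# The caller decides how to normalize, e.g. by the number of unmasked examples.
mean_loss = jnp.sum(losses) / jnp.sum(weights)
```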