From b24812ff74d1353a2d56d3cffb86952298836f04 Mon Sep 17 00:00:00 2001 From: Aaron Defazio Date: Thu, 5 Sep 2024 15:20:04 +0000 Subject: [PATCH 01/22] BN Fixes --- algorithmic_efficiency/pytorch_utils.py | 14 ++++++++------ .../librispeech_pytorch/models.py | 11 ++++++----- .../librispeech_pytorch/models.py | 10 +++++----- 3 files changed, 19 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 4f6c254bd..2e5828912 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,7 +57,6 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() - def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -67,10 +66,13 @@ def update_batch_norm_fn(module: spec.ParameterContainer, ) if isinstance(module, bn_layers): if not update_batch_norm: - module.eval() - module.momentum_backup = module.momentum + if not hasattr(module, 'momentum_backup'): + module.momentum_backup = module.momentum + # module.momentum can be float or torch.Tensor. - module.momentum = 0. * module.momentum_backup + if torch.is_tensor(module.momentum_backup): + module.momentum = torch.zeros_like(module.momentum_backup) + else: + module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup - module.track_running_stats = update_batch_norm + module.momentum = module.momentum_backup \ No newline at end of file diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index 502cb093e..cab73df4a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -40,7 +40,7 @@ class ConformerConfig: time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True input_dropout_rate: float = 0.1 - batch_norm_momentum: float = 0.999 + batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 use_specaug: bool = True attention_temperature: float = 1.0 @@ -369,10 +369,11 @@ def forward(self, inputs, input_paddings): mean = (masked_inp).sum(dim=(0, 1)) / count var = (torch.square(masked_inp - mean) * mask).sum(dim=(0, 1)) / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() + else: mean = self.running_mean var = self.running_var diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py index a5ee3fa0a..bdf556f1c 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_pytorch/models.py @@ -36,7 +36,7 @@ class DeepspeechConfig: time_mask_max_ratio: float = 0.05 time_masks_per_frame: float = 0.0 use_dynamic_time_mask_max_frames: bool = True - batch_norm_momentum: float = 0.999 + 
batch_norm_momentum: float = 1 - 0.999 batch_norm_epsilon: float = 0.001 # If None, defaults to 0.1. input_dropout_rate: Optional[float] = 0.1 @@ -264,10 +264,10 @@ def forward(self, inputs, input_paddings): sum_ = dist_nn.all_reduce(sum_) var = sum_ / count - self.running_mean = self.momentum * self.running_mean + ( - 1 - self.momentum) * mean.detach() - self.running_var = self.momentum * self.running_var + ( - 1 - self.momentum) * var.detach() + self.running_mean = (1 - self.momentum) * self.running_mean + ( + self.momentum) * mean.detach() + self.running_var = (1 - self.momentum) * self.running_var + ( + self.momentum) * var.detach() else: mean = self.running_mean var = self.running_var From e09bbf594150dae74b186ee354daa23d3f29de25 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 17:39:20 +0200 Subject: [PATCH 02/22] add `train_state` to all instances of `update_params', passing it by (shallow) copy in submission_runner --- DOCUMENTATION.md | 1 + algorithmic_efficiency/spec.py | 2 ++ .../external_tuning/jax_nadamw_full_budget.py | 2 ++ .../external_tuning/jax_nadamw_target_setting.py | 2 ++ .../external_tuning/pytorch_nadamw_full_budget.py | 4 +++- .../external_tuning/pytorch_nadamw_target_setting.py | 4 +++- .../self_tuning/jax_nadamw_full_budget.py | 2 ++ .../self_tuning/jax_nadamw_target_setting.py | 2 ++ .../self_tuning/pytorch_nadamw_full_budget.py | 4 +++- .../self_tuning/pytorch_nadamw_target_setting.py | 4 +++- .../development_algorithms/cifar/cifar_jax/submission.py | 4 +++- .../development_algorithms/cifar/cifar_pytorch/submission.py | 4 +++- .../development_algorithms/mnist/mnist_jax/submission.py | 4 +++- .../development_algorithms/mnist/mnist_pytorch/submission.py | 4 +++- .../paper_baselines/adafactor/jax/submission.py | 4 +++- .../paper_baselines/adafactor/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/adamw/jax/submission.py | 4 +++- .../paper_baselines/adamw/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/lamb/jax/submission.py | 4 +++- .../paper_baselines/lamb/pytorch/submission.py | 4 +++- .../paper_baselines/momentum/jax/submission.py | 4 +++- .../paper_baselines/momentum/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/nadamw/jax/submission.py | 2 ++ .../paper_baselines/nadamw/pytorch/submission.py | 4 +++- .../paper_baselines/nesterov/jax/submission.py | 4 +++- .../paper_baselines/nesterov/pytorch/submission.py | 4 +++- reference_algorithms/paper_baselines/sam/jax/submission.py | 4 +++- .../paper_baselines/sam/pytorch/submission.py | 4 +++- .../paper_baselines/shampoo/jax/submission.py | 4 +++- .../target_setting_algorithms/jax_submission_base.py | 4 +++- .../target_setting_algorithms/pytorch_submission_base.py | 4 +++- submission_runner.py | 1 + submissions/template/submission.py | 3 ++- 33 files changed, 88 insertions(+), 25 deletions(-) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 607f47ead..8207691d6 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -199,6 +199,7 @@ def update_params( batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 285983957..7a16f0040 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -401,6 +401,7 @@ def init_optimizer_state(workload: Workload, Dict[str, Tensor], LossType, OptimizerState, + Dict[str, 
Any], List[Tuple[int, float]], int, RandomState @@ -422,6 +423,7 @@ def update_params(workload: Workload, batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: RandomState) -> UpdateReturn: diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 98193f01f..63cf25fe5 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 66fdc4ebb..ab0ee82b1 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index ebc49d428..c85cc6dd3 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 524bc20af..bb1278911 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with 
warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 4f53afb56..f6ada3c8e 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 60a1f784d..9c7f66c43 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -272,12 +272,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f8e87ec2a..2af6d548a 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 1de26417f..2e2385e29 100644 --- 
a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -244,12 +244,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del hyperparameters diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 2971efe9a..89aeac238 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,6 +118,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -125,6 +126,7 @@ def update_params(workload: spec.Workload, del current_params_types del loss_type del global_step + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state per_device_rngs = jax.random.split(rng, jax.local_device_count()) diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 358c6bffc..bcdab6fc3 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -61,6 +61,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -68,6 +69,7 @@ def update_params(workload: spec.Workload, del current_params_types del hyperparameters del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 896609d51..01a266eaf 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for 
MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -83,12 +83,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results del global_step diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index f1601e606..e72a9d823 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any import torch @@ -40,6 +40,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: @@ -47,6 +48,7 @@ def update_params(workload: spec.Workload, del hyperparameters del loss_type del current_params_types + del train_state del eval_results del global_step diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 2dd85c29b..ea440cce7 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index e6fef17dc..30d6942e4 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -198,12 +198,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: 
spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 80a963600..935d0d0ca 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -118,12 +118,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 32353e5b4..ddd17b3b2 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -59,12 +59,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 27d635ee9..3944d6483 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -126,12 +126,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7d0d8763e..20bc80d23 100644 --- 
a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -197,12 +197,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index cccb3c1b5..b2db0c728 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -152,12 +152,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index ec5c0b31c..533e9fed4 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import optax @@ -75,12 +75,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 98193f01f..63cf25fe5 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -260,12 +260,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], 
global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index ebc49d428..c85cc6dd3 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -232,12 +232,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f3b0aeed4..f79bc34b4 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -152,12 +152,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index fe9154934..330e344c1 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import optax @@ -75,12 +75,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py 
b/reference_algorithms/paper_baselines/sam/jax/submission.py index 85b3d7441..5448ff1f2 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple +from typing import Dict, Iterator, List, Optional, Tuple, Any from flax import jax_utils import jax @@ -205,12 +205,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 2cab75972..967d53549 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple +from typing import Callable, Dict, Iterator, List, Tuple, Any from absl import logging import torch @@ -139,12 +139,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 9c6b66b7f..104ae3ce3 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from flax import jax_utils import jax @@ -121,12 +121,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 2a641b520..e66b1ab23 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple +from typing 
import Dict, List, Tuple, Any import jax from jax import lax @@ -77,12 +77,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results optimizer_state, opt_update_fn = optimizer_state diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index f9e40212b..c031f3ac4 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Any from absl import logging import torch @@ -20,12 +20,14 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type + del train_state del eval_results current_model = current_param_container diff --git a/submission_runner.py b/submission_runner.py index 551173bf5..aef7fafb0 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -357,6 +357,7 @@ def train_once( batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state=train_state.copy(), eval_results=eval_results, global_step=global_step, rng=update_rng) diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 5ef195db5..fb9b1cad1 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Dict, Iterator, List, Tuple +from typing import Dict, Iterator, List, Tuple, Any from algorithmic_efficiency import spec @@ -30,6 +30,7 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, + train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, rng: spec.RandomState) -> spec.UpdateReturn: From f15e227ee96ca181d670d6dd06bada647986c9ee Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 17:44:41 +0200 Subject: [PATCH 03/22] update DOCS --- DOCUMENTATION.md | 1 + 1 file changed, 1 insertion(+) diff --git a/DOCUMENTATION.md b/DOCUMENTATION.md index 8207691d6..8722a441e 100644 --- a/DOCUMENTATION.md +++ b/DOCUMENTATION.md @@ -213,6 +213,7 @@ def update_params( - The `loss_fn` produces a loss per example and a summed loss (both only for one device), which both can be used. - Allowed to update state for the optimizer. - Uses the `model_fn` of the `workload` in order to decouple the loss from the model so that model outputs (forward passes) can be reused (by storing them in the optimizer state). 
+- The submission can access the elapsed training time and get further information about the evaluation through `train_state`. - The submission can access the target evaluation metric via the `workload` variable. - **A call to this function will be considered a step** - The time between a call to this function and the next call to this function will be considered the per-step time. From 107c6b6e2ad3312d77ad4e99034f17a31f2967c6 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 18:05:01 +0200 Subject: [PATCH 04/22] update test --- tests/reference_algorithm_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..938c4fa11 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -471,6 +471,7 @@ def _test_submission(workload_name, batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, + train_state={}, eval_results=[], global_step=global_step, rng=update_rng) From d4ad0eb06df5f323a4d383e70715eeb99181d294 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:02:05 +0200 Subject: [PATCH 05/22] fix linting --- reference_algorithms/paper_baselines/momentum/jax/submission.py | 2 +- .../paper_baselines/momentum/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/nesterov/jax/submission.py | 2 +- .../paper_baselines/nesterov/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/sam/pytorch/submission.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index b2db0c728..dc101896b 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 533e9fed4..52aba82bf 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index f79bc34b4..e47c7fa0c 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 330e344c1..442949866 100644 --- 
a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 967d53549..15b6b6858 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Callable, Dict, Iterator, List, Tuple, Any +from typing import Any, Callable, Dict, Iterator, List, Tuple from absl import logging import torch From cb7e162230d6ca3849a183435d05c6f802498de1 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:03:42 +0200 Subject: [PATCH 06/22] fix isort --- .../external_tuning/pytorch_nadamw_full_budget.py | 2 +- .../external_tuning/pytorch_nadamw_target_setting.py | 2 +- .../self_tuning/pytorch_nadamw_full_budget.py | 2 +- .../self_tuning/pytorch_nadamw_target_setting.py | 2 +- .../development_algorithms/cifar/cifar_jax/submission.py | 2 +- .../development_algorithms/cifar/cifar_pytorch/submission.py | 2 +- .../development_algorithms/mnist/mnist_jax/submission.py | 2 +- .../development_algorithms/mnist/mnist_pytorch/submission.py | 2 +- .../paper_baselines/adafactor/jax/submission.py | 2 +- .../paper_baselines/adafactor/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/adamw/jax/submission.py | 2 +- .../paper_baselines/adamw/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/lamb/jax/submission.py | 2 +- reference_algorithms/paper_baselines/lamb/pytorch/submission.py | 2 +- .../paper_baselines/nadamw/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/shampoo/jax/submission.py | 2 +- submissions/template/submission.py | 2 +- 17 files changed, 17 insertions(+), 17 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index c85cc6dd3..72a3bf289 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index bb1278911..934538b63 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py 
b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 2af6d548a..f968d4abf 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 2e2385e29..14c22141c 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 89aeac238..7e41e9fd7 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index bcdab6fc3..81110bae6 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 01a266eaf..3f75c9904 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index e72a9d823..d326f4035 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple import torch diff --git 
a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index ea440cce7..39cf3d4f9 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 30d6942e4..880f9168d 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 935d0d0ca..06eeacb39 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index ddd17b3b2..0710fb9a0 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 3944d6483..891da63be 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 20bc80d23..7886dc75d 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index c85cc6dd3..72a3bf289 
100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from absl import logging import torch diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 104ae3ce3..e853a821b 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from flax import jax_utils import jax diff --git a/submissions/template/submission.py b/submissions/template/submission.py index fb9b1cad1..9bfb23367 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. """ -from typing import Dict, Iterator, List, Tuple, Any +from typing import Any, Dict, Iterator, List, Tuple from algorithmic_efficiency import spec From 1f59285fa1ae8eb8e4cce10cba4db486bf49f8e8 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Thu, 3 Oct 2024 19:05:43 +0200 Subject: [PATCH 07/22] fix import sort --- reference_algorithms/paper_baselines/sam/jax/submission.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 2 +- .../target_setting_algorithms/pytorch_submission_base.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 5448ff1f2..95bea68aa 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SAM optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Dict, Iterator, List, Optional, Tuple, Any +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index e66b1ab23..a98d134fc 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index c031f3ac4..586429e37 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Dict, List, Tuple, Any +from typing import Any, Dict, List, Tuple from absl import logging import torch From 
f574bf04dda725f790bb6ffaf2ca62b260b132d8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 01:01:14 +0000 Subject: [PATCH 08/22] add use_running_average_bn arg for jax --- .../workloads/cifar/cifar_jax/models.py | 9 +++- .../workloads/cifar/cifar_jax/workload.py | 9 ++-- .../imagenet_resnet/imagenet_jax/models.py | 9 +++- .../imagenet_resnet/imagenet_jax/workload.py | 9 ++-- .../librispeech_jax/models.py | 49 ++++++++++++------- .../librispeech_jax/workload.py | 9 ++-- 6 files changed, 63 insertions(+), 31 deletions(-) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 834c93b7a..09338ca82 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -28,11 +28,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index b019d1cee..019dde38c 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -110,7 +110,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -119,14 +120,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 99a9b0513..2e680cbd9 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -84,11 +84,16 @@ class ResNet(nn.Module): @nn.compact def __call__(self, x: spec.Tensor, - update_batch_norm: bool = True) -> spec.Tensor: + update_batch_norm: bool = True, + use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) + + # Preserve default behavior for backwards compatibility + if use_running_average_bn is None: + use_running_average_bn = not update_batch_norm norm = 
functools.partial( nn.BatchNorm, - use_running_average=not update_batch_norm, + use_running_average=use_running_average_bn, momentum=0.9, epsilon=1e-5, dtype=self.dtype) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index d8de214f5..46168c2a0 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -148,7 +148,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} @@ -157,14 +158,16 @@ def model_fn( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return logits, new_model_state else: logits = self._model.apply( variables, augmented_and_preprocessed_input_batch['inputs'], update_batch_norm=update_batch_norm, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return logits, model_state # Does NOT apply regularization, which is left to the submitter to do in diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index ed05f4335..077ff0f89 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,7 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +462,12 @@ def __call__(self, inputs, input_paddings, train): momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if train: + if use_running_average_bn: + mean = self.ra_mean.value + var = self.ra_var.value + + else: + # compute batch statistics mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=True) count_v = jnp.sum( @@ -477,17 +482,14 @@ def __call__(self, inputs, input_paddings, train): keepdims=True) var = sum_vv / count_v - - self.ra_mean.value = momentum * \ - self.ra_mean.value + (1 - momentum) * mean - self.ra_var.value = momentum * \ - self.ra_var.value + (1 - momentum) * var - else: - mean = self.ra_mean.value - var = self.ra_var.value - + + if update_batch_norm: + self.ra_mean.value = momentum * \ + self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * \ + self.ra_var.value + (1 - momentum) * var + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) - bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -517,7 +519,7 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): config = self.config inputs = 
LayerNorm(dim=config.encoder_dim)(inputs) @@ -546,7 +548,7 @@ def __call__(self, inputs, input_paddings, train): kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, train) + inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -586,7 +588,7 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -597,7 +599,7 @@ def __call__(self, inputs, input_paddings, train): inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train) + ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) @@ -629,12 +631,23 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, inputs, input_paddings, train): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs output_paddings = input_paddings + # Set BN args if not supplied for backwards compatibility + if update_batch_norm is None: + update_batch_norm = train + if use_running_average_bn is None: + use_running_average_bn = not train + # Compute normalized log mel spectrograms from input audio signal. preprocessing_config = preprocessor.LibrispeechPreprocessingConfig() outputs, output_paddings = preprocessor.MelFilterbankFrontend( @@ -660,7 +673,7 @@ def __call__(self, inputs, input_paddings, train): # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train) + outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
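The None defaults in the encoder's new signature are what keep old callers working: a call that only passes train behaves as before (update_batch_norm falls back to train, use_running_average_bn to not train), while the combination this patch enables is "normalize with the current batch's statistics but leave the running averages untouched". Below is a minimal sketch of both call patterns, not part of the patch itself; it assumes `model` is an instance of the encoder module above and that `params`, `batch_stats`, `inputs` and `paddings` come from model initialization and the input pipeline (those names are illustrative).

import jax

rng = jax.random.PRNGKey(0)
variables = {'params': params, 'batch_stats': batch_stats}

# Old-style call: the omitted flags resolve from train=True, i.e.
# update_batch_norm=True and use_running_average_bn=False, so batch
# statistics are used and the running averages are updated.
(logits, logit_paddings), new_batch_stats = model.apply(
    variables,
    inputs,
    paddings,
    train=True,
    rngs={'dropout': rng},
    mutable=['batch_stats'])

# New combination: BatchNorm normalizes with this batch's statistics
# (use_running_average_bn=False) but does not touch ra_mean / ra_var
# (update_batch_norm=False), so nothing needs to be mutable.
logits, logit_paddings = model.apply(
    variables,
    inputs,
    paddings,
    train=False,
    update_batch_norm=False,
    use_running_average_bn=False,
    mutable=False)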
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index f4d1ab0f3..6c55acfb0 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -107,7 +107,8 @@ def model_fn( model_state: spec.ModelAuxiliaryState, mode: spec.ForwardPassMode, rng: spec.RandomState, - update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + update_batch_norm: bool, + use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN @@ -118,7 +119,8 @@ def model_fn( input_paddings, train=True, rngs={'dropout' : rng}, - mutable=['batch_stats']) + mutable=['batch_stats'], + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), new_model_state else: logits, logit_paddings = self._model.apply( @@ -126,7 +128,8 @@ def model_fn( inputs, input_paddings, train=False, - mutable=False) + mutable=False, + use_running_average_bn=use_running_average_bn) return (logits, logit_paddings), model_state def _build_input_queue( From 7ca8365a7aba462181736b8d39382162c9bb1ad6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:01:58 +0000 Subject: [PATCH 09/22] formatting --- algorithmic_efficiency/pytorch_utils.py | 3 +- .../workloads/cifar/cifar_jax/models.py | 6 +-- .../workloads/cifar/cifar_jax/workload.py | 3 +- .../imagenet_resnet/imagenet_jax/models.py | 4 +- .../imagenet_resnet/imagenet_jax/workload.py | 3 +- .../librispeech_jax/models.py | 49 +++++++++++++------ .../librispeech_jax/workload.py | 3 +- .../librispeech_pytorch/models.py | 2 +- 8 files changed, 49 insertions(+), 24 deletions(-) diff --git a/algorithmic_efficiency/pytorch_utils.py b/algorithmic_efficiency/pytorch_utils.py index 2e5828912..590f500fa 100644 --- a/algorithmic_efficiency/pytorch_utils.py +++ b/algorithmic_efficiency/pytorch_utils.py @@ -57,6 +57,7 @@ def sync_ddp_time(time: float, device: torch.device) -> float: dist.all_reduce(time_tensor, op=dist.ReduceOp.MAX) return time_tensor.item() + def update_batch_norm_fn(module: spec.ParameterContainer, update_batch_norm: bool) -> None: bn_layers = ( @@ -75,4 +76,4 @@ def update_batch_norm_fn(module: spec.ParameterContainer, else: module.momentum = 0.0 elif hasattr(module, 'momentum_backup'): - module.momentum = module.momentum_backup \ No newline at end of file + module.momentum = module.momentum_backup diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py index 09338ca82..059352fb6 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/models.py @@ -31,10 +31,10 @@ def __call__(self, update_batch_norm: bool = True, use_running_average_bn: bool = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - - # Preserve default behavior for backwards compatibility + + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, 
use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py index 019dde38c..8268c6ca3 100644 --- a/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py +++ b/algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py @@ -111,7 +111,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py index 2e680cbd9..34cd17440 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/models.py @@ -88,9 +88,9 @@ def __call__(self, use_running_average_bn: Optional[bool] = None) -> spec.Tensor: conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) - # Preserve default behavior for backwards compatibility + # Preserve default behavior for backwards compatibility if use_running_average_bn is None: - use_running_average_bn = not update_batch_norm + use_running_average_bn = not update_batch_norm norm = functools.partial( nn.BatchNorm, use_running_average=use_running_average_bn, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py index 46168c2a0..2747fc2db 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/workload.py @@ -149,7 +149,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool] = None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: del mode del rng variables = {'params': params, **model_state} diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index 077ff0f89..db92f56d4 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -454,7 +454,11 @@ def setup(self): self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @nn.compact - def __call__(self, inputs, input_paddings, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + update_batch_norm, + use_running_average_bn): rank = inputs.ndim reduce_over_dims = list(range(0, rank - 1)) @@ -462,7 +466,7 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag momentum = self.config.batch_norm_momentum epsilon = self.config.batch_norm_epsilon - if use_running_average_bn: + if use_running_average_bn: mean = self.ra_mean.value var = self.ra_var.value @@ -482,13 +486,13 @@ def __call__(self, inputs, input_paddings, update_batch_norm, use_running_averag keepdims=True) var = sum_vv / count_v - + if update_batch_norm: self.ra_mean.value = momentum * \ 
self.ra_mean.value + (1 - momentum) * mean self.ra_var.value = momentum * \ self.ra_var.value + (1 - momentum) * var - + inv = (1 + self.gamma) / jnp.sqrt(var + epsilon) bn_output = (inputs - mean) * inv + self.beta bn_output *= 1.0 - padding @@ -519,7 +523,12 @@ class ConvolutionBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average_bn): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average_bn): config = self.config inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -548,7 +557,10 @@ def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running kernel_init=nn.initializers.xavier_uniform())( inputs) - inputs = BatchNorm(config)(inputs, input_paddings, update_batch_norm, use_running_average_bn) + inputs = BatchNorm(config)(inputs, + input_paddings, + update_batch_norm, + use_running_average_bn) if config.activation_function_name == 'swish': activation_fn = nn.swish elif config.activation_function_name == 'gelu': @@ -588,7 +600,12 @@ class ConformerBlock(nn.Module): config: ConformerConfig @nn.compact - def __call__(self, inputs, input_paddings, train, update_batch_norm, use_running_average): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm, + use_running_average): config = self.config padding_mask = jnp.expand_dims(1 - input_paddings, -1) @@ -631,12 +648,12 @@ def setup(self): .use_dynamic_time_mask_max_frames) @nn.compact - def __call__(self, - inputs, - input_paddings, - train, - update_batch_norm: Optional[bool] = None, - use_running_average_bn: Optional[bool] = None): + def __call__(self, + inputs, + input_paddings, + train, + update_batch_norm: Optional[bool] = None, + use_running_average_bn: Optional[bool] = None): config = self.config outputs = inputs @@ -673,7 +690,11 @@ def __call__(self, # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): - outputs = ConformerBlock(config)(outputs, output_paddings, train, update_batch_norm, use_running_average_bn) + outputs = ConformerBlock(config)(outputs, + output_paddings, + train, + update_batch_norm, + use_running_average_bn) outputs = LayerNorm(config.encoder_dim)(outputs) # Run the decoder which in this case is a trivial projection layer. 
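At the workload level the same switch is surfaced through model_fn, so a submission can ask for batch-statistics normalization at evaluation time without mutating model_state. A rough usage sketch follows; `workload`, `params`, `model_state`, `batch` and `rng` are assumed to come from the surrounding training loop, and use_running_average_bn is the only keyword argument introduced by this patch series.

from algorithmic_efficiency import spec

# Default behavior: the stored running averages are used for
# normalization and are left unchanged.
(logits, logit_paddings), _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.EVAL,
    rng,
    update_batch_norm=False)

# New option: normalize with the current batch's statistics while
# keeping the stored running averages unchanged.
(logits, logit_paddings), _ = workload.model_fn(
    params,
    batch,
    model_state,
    spec.ForwardPassMode.EVAL,
    rng,
    update_batch_norm=False,
    use_running_average_bn=False)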
diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 6c55acfb0..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -108,7 +108,8 @@ def model_fn( mode: spec.ForwardPassMode, rng: spec.RandomState, update_batch_norm: bool, - use_running_average_bn: Optional[bool]=None) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py index cab73df4a..61400806a 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_pytorch/models.py @@ -373,7 +373,7 @@ def forward(self, inputs, input_paddings): self.momentum) * mean.detach() self.running_var = (1 - self.momentum) * self.running_var + ( self.momentum) * var.detach() - + else: mean = self.running_mean var = self.running_var From 39132387e411a5e869a98c5b57fbd1e7b0d12194 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 22:37:04 +0000 Subject: [PATCH 10/22] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index db92f56d4..a7f786c32 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,7 +616,12 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, input_paddings, train, update_batch_norm, use_running_average) + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, + use_running_average + ) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( inputs, padding_mask, train) From baac0a452871ef5a940c07dbfd64f6d3b9c5427d Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:26:42 +0000 Subject: [PATCH 11/22] formatting --- .../librispeech_conformer/librispeech_jax/models.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py index a7f786c32..cb6287c5e 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/models.py @@ -616,10 +616,10 @@ def __call__(self, inputs, input_paddings, train) inputs = inputs + \ - ConvolutionBlock(config)(inputs, - input_paddings, - train, - update_batch_norm, + ConvolutionBlock(config)(inputs, + input_paddings, + train, + update_batch_norm, use_running_average ) From 
087fd5c1e8400d1ab162cbf79e0fad6a828dae5f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Thu, 17 Oct 2024 23:58:10 +0000 Subject: [PATCH 12/22] debugging --- .../workloads/librispeech_conformer/librispeech_jax/workload.py | 1 + 1 file changed, 1 insertion(+) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index e362f973b..3caf151ab 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,6 +113,7 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN + print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, From c5c36c291f2c2a5a21bc0b60961a7016039e93ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:30:40 +0000 Subject: [PATCH 13/22] add separate model_fn for deepspeech jax without use_running_average_bn --- .../librispeech_jax/workload.py | 1 - .../librispeech_jax/workload.py | 31 +++++++++++++++++++ tests/reference_algorithm_tests.py | 1 + 3 files changed, 32 insertions(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py index 3caf151ab..e362f973b 100644 --- a/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_conformer/librispeech_jax/workload.py @@ -113,7 +113,6 @@ def model_fn( variables = {'params': params, **model_state} inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] is_train_mode = mode == spec.ForwardPassMode.TRAIN - print(type(use_running_average_bn)) if update_batch_norm or is_train_mode: (logits, logit_paddings), new_model_state = self._model.apply( variables, diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 8473fac0f..c81b1b0b4 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,6 +55,37 @@ def init_model_fn( model_state = jax_utils.replicate(model_state) params = jax_utils.replicate(params) return params, model_state + + def model_fn( + self, + params: spec.ParameterContainer, + augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], + model_state: spec.ModelAuxiliaryState, + mode: spec.ForwardPassMode, + rng: spec.RandomState, + update_batch_norm: bool, + use_running_average_bn: Optional[bool] = None + ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]: + variables = {'params': params, **model_state} + inputs, input_paddings = augmented_and_preprocessed_input_batch['inputs'] + is_train_mode = mode == spec.ForwardPassMode.TRAIN + if update_batch_norm or is_train_mode: + (logits, logit_paddings), new_model_state = self._model.apply( + variables, + inputs, + input_paddings, + train=True, + rngs={'dropout' : rng}, + mutable=['batch_stats']) + return (logits, logit_paddings), new_model_state + else: + logits, logit_paddings = 
self._model.apply( + variables, + inputs, + input_paddings, + train=False, + mutable=False) + return (logits, logit_paddings), model_state def is_output_params(self, param_key: spec.ParameterKey) -> bool: return param_key == 'Dense_0' diff --git a/tests/reference_algorithm_tests.py b/tests/reference_algorithm_tests.py index 74c06e180..5e563d2f9 100644 --- a/tests/reference_algorithm_tests.py +++ b/tests/reference_algorithm_tests.py @@ -408,6 +408,7 @@ def _test_submission(workload_name, workload_path=workload_metadata['workload_path'], workload_class_name=workload_metadata['workload_class_name'], return_class=True) + print(f'Workload class for {workload_name} is {workload_class}') submission_module_path = workloads.convert_filepath_to_module(submission_path) submission_module = importlib.import_module(submission_module_path) From 783aab4a3c2952823290c3e3881b0e423231a2ae Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:32:36 +0000 Subject: [PATCH 14/22] fix syntax error --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index c81b1b0b4..05fdf90e7 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -56,7 +56,7 @@ def init_model_fn( params = jax_utils.replicate(params) return params, model_state - def model_fn( + def model_fn( self, params: spec.ParameterContainer, augmented_and_preprocessed_input_batch: Dict[str, spec.Tensor], From 28e7e21a334001f3e62ce15875f3126af0affbd6 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:37:29 +0000 Subject: [PATCH 15/22] fix --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 05fdf90e7..2d46960ed 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional +from typing import Optional, Dict, Tuple from flax import jax_utils import jax From b063f9f7fa5288736c006ff111454e77869ada8f Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:43:44 +0000 Subject: [PATCH 16/22] fix import order --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index 2d46960ed..e5030f426 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -1,5 +1,5 @@ import functools -from typing import Optional, Dict, Tuple +from typing import Dict, Optional, Tuple from flax import jax_utils import jax From 894cd872aa07d97e2a4fe0ce9e5a0dcdb790bfb8 Mon Sep 17 00:00:00 2001 From: Priya Kasimbeg Date: Fri, 18 Oct 2024 00:53:11 +0000 Subject: [PATCH 17/22] 
formatting --- .../librispeech_deepspeech/librispeech_jax/workload.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py index e5030f426..a0db6d607 100644 --- a/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py +++ b/algorithmic_efficiency/workloads/librispeech_deepspeech/librispeech_jax/workload.py @@ -55,7 +55,7 @@ def init_model_fn( model_state = jax_utils.replicate(model_state) params = jax_utils.replicate(params) return params, model_state - + def model_fn( self, params: spec.ParameterContainer, From ce8eb182043258fc2d7823d84bcc4591441dc159 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 10:50:06 +0200 Subject: [PATCH 18/22] ensure backward compatibility --- algorithmic_efficiency/spec.py | 8 ++++---- submission_runner.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/algorithmic_efficiency/spec.py b/algorithmic_efficiency/spec.py index 7a16f0040..7bc86b505 100644 --- a/algorithmic_efficiency/spec.py +++ b/algorithmic_efficiency/spec.py @@ -401,10 +401,10 @@ def init_optimizer_state(workload: Workload, Dict[str, Tensor], LossType, OptimizerState, - Dict[str, Any], List[Tuple[int, float]], int, - RandomState + RandomState, + Optional[Dict[str, Any]] ], UpdateReturn] @@ -423,10 +423,10 @@ def update_params(workload: Workload, batch: Dict[str, Tensor], loss_type: LossType, optimizer_state: OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: RandomState) -> UpdateReturn: + rng: RandomState, + train_state: Optional[Dict[str, Any]] = None) -> UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" pass diff --git a/submission_runner.py b/submission_runner.py index aef7fafb0..1a66acc58 100644 --- a/submission_runner.py +++ b/submission_runner.py @@ -17,11 +17,13 @@ import datetime import gc import importlib +from inspect import signature import itertools import json import os import struct import time +from types import MappingProxyType from typing import Any, Dict, Optional, Tuple from absl import app @@ -273,6 +275,10 @@ def train_once( hyperparameters, opt_init_rng) logging.info('Initializing metrics bundle.') + + # Check if 'train_state' is in the function signature + needs_train_state = 'train_state' in signature(update_params).parameters + # Bookkeeping. 
train_state = { 'validation_goal_reached': False, @@ -357,10 +363,11 @@ def train_once( batch=batch, loss_type=workload.loss_type, optimizer_state=optimizer_state, - train_state=train_state.copy(), eval_results=eval_results, global_step=global_step, - rng=update_rng) + rng=update_rng, + **({'train_state': MappingProxyType(train_state)} + if needs_train_state else {})) except spec.TrainingCompleteError: train_state['training_complete'] = True global_step += 1 From 5a06a0dc670014db1eeebb1ec1960e984e379db0 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:38:05 +0200 Subject: [PATCH 19/22] adding train_state to all submissions --- .../external_tuning/jax_nadamw_full_budget.py | 5 +++-- .../external_tuning/jax_nadamw_target_setting.py | 5 +++-- .../external_tuning/pytorch_nadamw_full_budget.py | 7 ++++--- .../external_tuning/pytorch_nadamw_target_setting.py | 7 ++++--- .../self_tuning/jax_nadamw_full_budget.py | 5 +++-- .../self_tuning/jax_nadamw_target_setting.py | 5 +++-- .../self_tuning/pytorch_nadamw_full_budget.py | 7 ++++--- .../self_tuning/pytorch_nadamw_target_setting.py | 7 ++++--- .../development_algorithms/cifar/cifar_jax/submission.py | 7 ++++--- .../cifar/cifar_pytorch/submission.py | 7 ++++--- .../development_algorithms/mnist/mnist_jax/submission.py | 7 ++++--- .../mnist/mnist_pytorch/submission.py | 7 ++++--- .../paper_baselines/adafactor/jax/submission.py | 7 ++++--- .../paper_baselines/adafactor/pytorch/submission.py | 7 ++++--- .../paper_baselines/adamw/jax/submission.py | 7 ++++--- .../paper_baselines/adamw/pytorch/submission.py | 7 ++++--- .../paper_baselines/lamb/jax/submission.py | 7 ++++--- .../paper_baselines/lamb/pytorch/submission.py | 7 ++++--- .../paper_baselines/momentum/jax/submission.py | 5 +++-- .../paper_baselines/momentum/pytorch/submission.py | 5 +++-- .../paper_baselines/nadamw/jax/submission.py | 5 +++-- .../paper_baselines/nadamw/pytorch/submission.py | 7 ++++--- .../paper_baselines/nesterov/jax/submission.py | 5 +++-- .../paper_baselines/nesterov/pytorch/submission.py | 5 +++-- reference_algorithms/paper_baselines/sam/jax/submission.py | 5 +++-- .../paper_baselines/sam/pytorch/submission.py | 5 +++-- .../paper_baselines/shampoo/jax/submission.py | 7 ++++--- .../target_setting_algorithms/jax_submission_base.py | 5 +++-- .../target_setting_algorithms/pytorch_submission_base.py | 5 +++-- submissions/template/submission.py | 7 ++++--- 30 files changed, 107 insertions(+), 77 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index 63cf25fe5..b390639f3 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index ab0ee82b1..88725d5c3 100644 --- 
a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 72a3bf289..3fc054984 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index 934538b63..f218184d7 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index f6ada3c8e..14bca5730 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: 
spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 9c7f66c43..4e1e523a2 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -272,10 +272,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index f968d4abf..076658093 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index 14c22141c..d9dde586e 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -244,10 +244,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index 7e41e9fd7..abb598fd4 
100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for CIFAR10.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index 81110bae6..def94296b 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -1,6 +1,6 @@ """Training algorithm track submission functions for CIFAR10.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch from torch.optim.lr_scheduler import CosineAnnealingLR @@ -61,10 +61,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 3f75c9904..4fd7d2212 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -1,7 +1,7 @@ """Training algorithm track submission functions for MNIST.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -83,10 +83,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index d326f4035..c14de49ab 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -1,6 
+1,6 @@ """Training algorithm track submission functions for MNIST.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple import torch @@ -40,10 +40,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index 39cf3d4f9..ce4bfebb0 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an Adafactor optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 880f9168d..17c5d8a03 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for Adafactor in PyTorch.""" from functools import partial -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -198,10 +198,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 06eeacb39..793a3f1de 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -118,10 +118,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: 
spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 0710fb9a0..225924b98 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for an AdamW optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -59,10 +59,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 891da63be..63b0cb219 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -126,10 +126,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7886dc75d..7c545d7ab 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for a LAMB optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -197,10 +197,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, 
updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index dc101896b..b173ba8ba 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 52aba82bf..c063f0a64 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index 63cf25fe5..b390639f3 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -260,10 +260,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 72a3bf289..3fc054984 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -1,7 +1,7 @@ """Submission file for an NAdamW optimizer with warmup+cosine LR in PyTorch.""" import math -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from absl import logging import torch @@ -232,10 +232,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return 
(updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index e47c7fa0c..35ef2bfa8 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -152,10 +152,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 442949866..0b7cc570b 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -75,10 +75,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index 95bea68aa..da2208519 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -205,10 +205,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index 15b6b6858..a793673f9 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -139,10 +139,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 
e853a821b..504dff0d1 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a Shampoo optimizer with warmup+cosine LR in Jax.""" import functools -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax @@ -121,10 +121,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index a98d134fc..999422fb0 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -77,10 +77,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 586429e37..92f222a18 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -20,10 +20,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/submissions/template/submission.py b/submissions/template/submission.py index 9bfb23367..b8a394322 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -4,7 +4,7 @@ and https://github.com/mlcommons/algorithmic-efficiency/blob/main/DOCUMENTATION.md#disallowed-submissions for guidelines. 
""" -from typing import Any, Dict, Iterator, List, Tuple +from typing import Any, Dict, Iterator, List, Optional, Tuple from algorithmic_efficiency import spec @@ -30,10 +30,11 @@ def update_params(workload: spec.Workload, batch: Dict[str, spec.Tensor], loss_type: spec.LossType, optimizer_state: spec.OptimizerState, - train_state: Dict[str, Any], eval_results: List[Tuple[int, float]], global_step: int, - rng: spec.RandomState) -> spec.UpdateReturn: + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None + ) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn) From 86114ef970c832ea5b8ed15c47856e6d6d325df3 Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:47:47 +0200 Subject: [PATCH 20/22] fix missing import Optional --- reference_algorithms/paper_baselines/momentum/jax/submission.py | 2 +- .../paper_baselines/momentum/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/nesterov/jax/submission.py | 2 +- .../paper_baselines/nesterov/pytorch/submission.py | 2 +- reference_algorithms/paper_baselines/sam/pytorch/submission.py | 2 +- .../target_setting_algorithms/jax_submission_base.py | 2 +- .../target_setting_algorithms/pytorch_submission_base.py | 2 +- 7 files changed, 7 insertions(+), 7 deletions(-) diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index b173ba8ba..346abe652 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with HeavyBall momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index c063f0a64..090a8bc01 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with HeavyBall momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index 35ef2bfa8..fa5329778 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -1,7 +1,7 @@ """Submission file for a SGD with Nesterov momentum optimizer in Jax.""" import functools -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from flax import jax_utils import jax diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index 0b7cc570b..ce0854f7d 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SGD with Nesterov momentum optimizer in PyTorch.""" -from typing import Any, Callable, Dict, 
Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import optax diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index a793673f9..e9c9c9bc4 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -1,6 +1,6 @@ """Submission file for a SAM optimizer with warmup+cosine LR in PyTorch.""" -from typing import Any, Callable, Dict, Iterator, List, Tuple +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple from absl import logging import torch diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 999422fb0..6914da94e 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -1,6 +1,6 @@ """Update submission function in Jax.""" import functools -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple import jax from jax import lax diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 92f222a18..606253e32 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -1,6 +1,6 @@ """Batch size and update submission functions in PyTorch.""" -from typing import Any, Dict, List, Tuple +from typing import Any, Dict, List, Optional, Tuple from absl import logging import torch From 1965241b5bc1995206569ca7f786f8f8e098a7ed Mon Sep 17 00:00:00 2001 From: Niccolo-Ajroldi Date: Fri, 25 Oct 2024 11:52:50 +0200 Subject: [PATCH 21/22] fix yapf --- .../external_tuning/jax_nadamw_full_budget.py | 26 +++++++++---------- .../jax_nadamw_target_setting.py | 26 +++++++++---------- .../pytorch_nadamw_full_budget.py | 26 +++++++++---------- .../pytorch_nadamw_target_setting.py | 26 +++++++++---------- .../self_tuning/jax_nadamw_full_budget.py | 26 +++++++++---------- .../self_tuning/jax_nadamw_target_setting.py | 26 +++++++++---------- .../self_tuning/pytorch_nadamw_full_budget.py | 26 +++++++++---------- .../pytorch_nadamw_target_setting.py | 26 +++++++++---------- .../cifar/cifar_jax/submission.py | 26 +++++++++---------- .../cifar/cifar_pytorch/submission.py | 26 +++++++++---------- .../mnist/mnist_jax/submission.py | 26 +++++++++---------- .../mnist/mnist_pytorch/submission.py | 26 +++++++++---------- .../adafactor/jax/submission.py | 26 +++++++++---------- .../adafactor/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/adamw/jax/submission.py | 26 +++++++++---------- .../adamw/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/lamb/jax/submission.py | 26 +++++++++---------- .../lamb/pytorch/submission.py | 26 +++++++++---------- .../momentum/jax/submission.py | 26 +++++++++---------- .../momentum/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/nadamw/jax/submission.py | 26 +++++++++---------- .../nadamw/pytorch/submission.py | 26 +++++++++---------- .../nesterov/jax/submission.py | 26 +++++++++---------- .../nesterov/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/sam/jax/submission.py | 26 +++++++++---------- 
.../paper_baselines/sam/pytorch/submission.py | 26 +++++++++---------- .../paper_baselines/shampoo/jax/submission.py | 26 +++++++++---------- .../jax_submission_base.py | 26 +++++++++---------- .../pytorch_submission_base.py | 26 +++++++++---------- submissions/template/submission.py | 26 +++++++++---------- 30 files changed, 390 insertions(+), 390 deletions(-) diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py index b390639f3..a235c50cd 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py index 88725d5c3..06413f681 100644 --- a/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git 
a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py index 3fc054984..0e654d43c 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_full_budget.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py index f218184d7..dd0b8b076 100644 --- a/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/external_tuning/pytorch_nadamw_target_setting.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py index 14bca5730..a9f048f03 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py @@ -264,19 +264,19 @@ def _loss_fn(params): return new_optimizer_state, 
updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py index 4e1e523a2..4d3d2b341 100644 --- a/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py @@ -264,19 +264,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py index 076658093..5a5319957 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_full_budget.py @@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: 
spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py index d9dde586e..699b11268 100644 --- a/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py +++ b/prize_qualification_baselines/self_tuning/pytorch_nadamw_target_setting.py @@ -236,19 +236,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py index abb598fd4..97d6df9f1 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): # Not allowed to update the model parameters, hyperparameters, global step, or # optimzier state. 
-def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py index def94296b..853064957 100644 --- a/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/cifar/cifar_pytorch/submission.py @@ -53,19 +53,19 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del current_params_types del hyperparameters diff --git a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py index 4fd7d2212..6d05954a1 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py @@ -75,19 +75,19 @@ def loss_fn(params): return new_optimizer_state, updated_params, new_model_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> 
spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py index c14de49ab..d27d7f742 100644 --- a/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py +++ b/reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py @@ -32,19 +32,19 @@ def init_optimizer_state(workload: spec.Workload, return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params).""" del hyperparameters del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/jax/submission.py b/reference_algorithms/paper_baselines/adafactor/jax/submission.py index ce4bfebb0..efe238f26 100644 --- a/reference_algorithms/paper_baselines/adafactor/jax/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> 
spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py index 17c5d8a03..377468612 100644 --- a/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adafactor/pytorch/submission.py @@ -190,19 +190,19 @@ def step(self, closure=None): return loss -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/jax/submission.py b/reference_algorithms/paper_baselines/adamw/jax/submission.py index 793a3f1de..31e0a6801 100644 --- a/reference_algorithms/paper_baselines/adamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/adamw/jax/submission.py @@ -110,19 +110,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py index 225924b98..27ceaeef7 100644 --- a/reference_algorithms/paper_baselines/adamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/adamw/pytorch/submission.py @@ -51,19 +51,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): 
return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/jax/submission.py b/reference_algorithms/paper_baselines/lamb/jax/submission.py index 63b0cb219..be13ab540 100644 --- a/reference_algorithms/paper_baselines/lamb/jax/submission.py +++ b/reference_algorithms/paper_baselines/lamb/jax/submission.py @@ -118,19 +118,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py index 7c545d7ab..d3b491e75 100644 --- a/reference_algorithms/paper_baselines/lamb/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/lamb/pytorch/submission.py @@ -189,19 +189,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def 
update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/jax/submission.py b/reference_algorithms/paper_baselines/momentum/jax/submission.py index 346abe652..3eef23942 100644 --- a/reference_algorithms/paper_baselines/momentum/jax/submission.py +++ b/reference_algorithms/paper_baselines/momentum/jax/submission.py @@ -144,19 +144,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py index 090a8bc01..cf474ebdd 100644 --- a/reference_algorithms/paper_baselines/momentum/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/momentum/pytorch/submission.py @@ -67,19 +67,19 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, 
updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/jax/submission.py b/reference_algorithms/paper_baselines/nadamw/jax/submission.py index b390639f3..a235c50cd 100644 --- a/reference_algorithms/paper_baselines/nadamw/jax/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/jax/submission.py @@ -252,19 +252,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py index 3fc054984..0e654d43c 100644 --- a/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nadamw/pytorch/submission.py @@ -224,19 +224,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/jax/submission.py b/reference_algorithms/paper_baselines/nesterov/jax/submission.py index fa5329778..553b3e478 100644 --- a/reference_algorithms/paper_baselines/nesterov/jax/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/jax/submission.py @@ -144,19 +144,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def 
update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py index ce0854f7d..ba8c69e6c 100644 --- a/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/nesterov/pytorch/submission.py @@ -67,19 +67,19 @@ def create_lr_schedule_fn( return lr_schedule_fn -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/jax/submission.py b/reference_algorithms/paper_baselines/sam/jax/submission.py index da2208519..b5c7069cb 100644 --- a/reference_algorithms/paper_baselines/sam/jax/submission.py +++ b/reference_algorithms/paper_baselines/sam/jax/submission.py @@ -197,19 +197,19 @@ def _loss_fn(params, update_batch_norm=True): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + 
current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/sam/pytorch/submission.py b/reference_algorithms/paper_baselines/sam/pytorch/submission.py index e9c9c9bc4..b69945d51 100644 --- a/reference_algorithms/paper_baselines/sam/pytorch/submission.py +++ b/reference_algorithms/paper_baselines/sam/pytorch/submission.py @@ -131,19 +131,19 @@ def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): return optimizer_state -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/paper_baselines/shampoo/jax/submission.py b/reference_algorithms/paper_baselines/shampoo/jax/submission.py index 504dff0d1..8f0b311a0 100644 --- a/reference_algorithms/paper_baselines/shampoo/jax/submission.py +++ b/reference_algorithms/paper_baselines/shampoo/jax/submission.py @@ -113,19 +113,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" 
del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/jax_submission_base.py b/reference_algorithms/target_setting_algorithms/jax_submission_base.py index 6914da94e..51b20181b 100644 --- a/reference_algorithms/target_setting_algorithms/jax_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/jax_submission_base.py @@ -69,19 +69,19 @@ def _loss_fn(params): return new_optimizer_state, updated_params, new_model_state, loss, grad_norm -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py index 606253e32..6203c58b3 100644 --- a/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py +++ b/reference_algorithms/target_setting_algorithms/pytorch_submission_base.py @@ -12,19 +12,19 @@ USE_PYTORCH_DDP = pytorch_setup()[0] -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """Return (updated_optimizer_state, updated_params, updated_model_state).""" del current_params_types del loss_type diff --git a/submissions/template/submission.py b/submissions/template/submission.py index b8a394322..ab98c9958 100644 --- a/submissions/template/submission.py +++ b/submissions/template/submission.py @@ -22,19 +22,19 @@ def init_optimizer_state(workload: spec.Workload, pass -def update_params(workload: spec.Workload, - current_param_container: spec.ParameterContainer, - current_params_types: spec.ParameterTypeTree, - model_state: spec.ModelAuxiliaryState, - 
hyperparameters: spec.Hyperparameters, - batch: Dict[str, spec.Tensor], - loss_type: spec.LossType, - optimizer_state: spec.OptimizerState, - eval_results: List[Tuple[int, float]], - global_step: int, - rng: spec.RandomState, - train_state: Optional[Dict[str, Any]] = None - ) -> spec.UpdateReturn: +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None) -> spec.UpdateReturn: """ Returns: (new_optimizer_state, update_fn) From f4c17f0223c2c50c8d4fbdf16cb6271c31d9c989 Mon Sep 17 00:00:00 2001 From: priyakasimbeg Date: Thu, 21 Nov 2024 09:33:38 -0800 Subject: [PATCH 22/22] Update README.md fix pytorch command --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5a1f10a33..516c8eb1b 100644 --- a/README.md +++ b/README.md @@ -89,7 +89,7 @@ python3 submission_runner.py \ --workload=mnist \ --experiment_dir=$HOME/experiments \ --experiment_name=my_first_experiment \ - --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \ + --submission_path=reference_algorithms/paper_baselines/adamw/pytorch/submission.py \ --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json ```
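For comparison, the JAX run of the same AdamW baseline differs from the command above only in the `--submission_path` value that this final patch replaces. A minimal sketch using just the flags visible in this README hunk (any additional flags the runner may require are not shown here):

```bash
python3 submission_runner.py \
    --workload=mnist \
    --experiment_dir=$HOME/experiments \
    --experiment_name=my_first_experiment \
    --submission_path=reference_algorithms/paper_baselines/adamw/jax/submission.py \
    --tuning_search_space=reference_algorithms/paper_baselines/adamw/tuning_search_space.json
```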